Initial support for external memory in gpu_predictor (#4284)
This commit is contained in:
parent
54980b8959
commit
feb6ae3e18
@ -284,6 +284,7 @@ class BatchIteratorImpl {
|
|||||||
public:
|
public:
|
||||||
virtual ~BatchIteratorImpl() {}
|
virtual ~BatchIteratorImpl() {}
|
||||||
virtual BatchIteratorImpl* Clone() = 0;
|
virtual BatchIteratorImpl* Clone() = 0;
|
||||||
|
virtual SparsePage& operator*() = 0;
|
||||||
virtual const SparsePage& operator*() const = 0;
|
virtual const SparsePage& operator*() const = 0;
|
||||||
virtual void operator++() = 0;
|
virtual void operator++() = 0;
|
||||||
virtual bool AtEnd() const = 0;
|
virtual bool AtEnd() const = 0;
|
||||||
@ -307,6 +308,11 @@ class BatchIterator {
|
|||||||
++(*impl_);
|
++(*impl_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SparsePage& operator*() {
|
||||||
|
CHECK(impl_ != nullptr);
|
||||||
|
return *(*impl_);
|
||||||
|
}
|
||||||
|
|
||||||
const SparsePage& operator*() const {
|
const SparsePage& operator*() const {
|
||||||
CHECK(impl_ != nullptr);
|
CHECK(impl_ != nullptr);
|
||||||
return *(*impl_);
|
return *(*impl_);
|
||||||
|
|||||||
@ -32,6 +32,10 @@ float SimpleDMatrix::GetColDensity(size_t cidx) {
|
|||||||
class SimpleBatchIteratorImpl : public BatchIteratorImpl {
|
class SimpleBatchIteratorImpl : public BatchIteratorImpl {
|
||||||
public:
|
public:
|
||||||
explicit SimpleBatchIteratorImpl(SparsePage* page) : page_(page) {}
|
explicit SimpleBatchIteratorImpl(SparsePage* page) : page_(page) {}
|
||||||
|
SparsePage& operator*() override {
|
||||||
|
CHECK(page_ != nullptr);
|
||||||
|
return *page_;
|
||||||
|
}
|
||||||
const SparsePage& operator*() const override {
|
const SparsePage& operator*() const override {
|
||||||
CHECK(page_ != nullptr);
|
CHECK(page_ != nullptr);
|
||||||
return *page_;
|
return *page_;
|
||||||
|
|||||||
@ -29,6 +29,7 @@ class SparseBatchIteratorImpl : public BatchIteratorImpl {
|
|||||||
explicit SparseBatchIteratorImpl(SparsePageSource* source) : source_(source) {
|
explicit SparseBatchIteratorImpl(SparsePageSource* source) : source_(source) {
|
||||||
CHECK(source_ != nullptr);
|
CHECK(source_ != nullptr);
|
||||||
}
|
}
|
||||||
|
SparsePage& operator*() override { return source_->Value(); }
|
||||||
const SparsePage& operator*() const override { return source_->Value(); }
|
const SparsePage& operator*() const override { return source_->Value(); }
|
||||||
void operator++() override { at_end_ = !source_->Next(); }
|
void operator++() override { at_end_ = !source_->Next(); }
|
||||||
bool AtEnd() const override { return at_end_; }
|
bool AtEnd() const override { return at_end_; }
|
||||||
|
|||||||
@ -104,6 +104,10 @@ void SparsePageSource::BeforeFirst() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SparsePage& SparsePageSource::Value() {
|
||||||
|
return *page_;
|
||||||
|
}
|
||||||
|
|
||||||
const SparsePage& SparsePageSource::Value() const {
|
const SparsePage& SparsePageSource::Value() const {
|
||||||
return *page_;
|
return *page_;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -43,6 +43,7 @@ class SparsePageSource : public DataSource {
|
|||||||
// implement BeforeFirst
|
// implement BeforeFirst
|
||||||
void BeforeFirst() override;
|
void BeforeFirst() override;
|
||||||
// implement Value
|
// implement Value
|
||||||
|
SparsePage& Value();
|
||||||
const SparsePage& Value() const override;
|
const SparsePage& Value() const override;
|
||||||
/*!
|
/*!
|
||||||
* \brief Create source by taking data from parser.
|
* \brief Create source by taking data from parser.
|
||||||
|
|||||||
@ -246,6 +246,7 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
void PredictInternal
|
void PredictInternal
|
||||||
(const SparsePage& batch, const MetaInfo& info,
|
(const SparsePage& batch, const MetaInfo& info,
|
||||||
HostDeviceVector<bst_float>* predictions,
|
HostDeviceVector<bst_float>* predictions,
|
||||||
|
const size_t batch_offset,
|
||||||
const gbm::GBTreeModel& model,
|
const gbm::GBTreeModel& model,
|
||||||
const thrust::host_vector<size_t>& h_tree_segments,
|
const thrust::host_vector<size_t>& h_tree_segments,
|
||||||
const thrust::host_vector<DevicePredictionNode>& h_nodes,
|
const thrust::host_vector<DevicePredictionNode>& h_nodes,
|
||||||
@ -284,8 +285,8 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
data_distr.Devices().Index(device_));
|
data_distr.Devices().Index(device_));
|
||||||
|
|
||||||
PredictKernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>>
|
PredictKernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>>
|
||||||
(dh::ToSpan(nodes_), predictions->DeviceSpan(device_), dh::ToSpan(tree_segments_),
|
(dh::ToSpan(nodes_), predictions->DeviceSpan(device_).subspan(batch_offset),
|
||||||
dh::ToSpan(tree_group_), batch.offset.DeviceSpan(device_),
|
dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(device_),
|
||||||
batch.data.DeviceSpan(device_), tree_begin, tree_end, info.num_col_,
|
batch.data.DeviceSpan(device_), tree_begin, tree_end, info.num_col_,
|
||||||
num_rows, entry_start, use_shared, model.param.num_output_group);
|
num_rows, entry_start, use_shared, model.param.num_output_group);
|
||||||
}
|
}
|
||||||
@ -324,18 +325,19 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
size_t i_batch = 0;
|
size_t i_batch = 0;
|
||||||
|
size_t batch_offset = 0;
|
||||||
for (const auto &batch : dmat->GetRowBatches()) {
|
for (auto &batch : dmat->GetRowBatches()) {
|
||||||
CHECK_EQ(i_batch, 0) << "External memory not supported";
|
CHECK(i_batch == 0 || devices_.Size() == 1) << "External memory not supported for multi-GPU";
|
||||||
// out_preds have been sharded and resized in InitOutPredictions()
|
// out_preds have been sharded and resized in InitOutPredictions()
|
||||||
batch.offset.Shard(GPUDistribution::Overlap(devices_, 1));
|
batch.offset.Shard(GPUDistribution::Overlap(devices_, 1));
|
||||||
std::vector<size_t> device_offsets;
|
std::vector<size_t> device_offsets;
|
||||||
DeviceOffsets(batch.offset, &device_offsets);
|
DeviceOffsets(batch.offset, &device_offsets);
|
||||||
batch.data.Shard(GPUDistribution::Explicit(devices_, device_offsets));
|
batch.data.Reshard(GPUDistribution::Explicit(devices_, device_offsets));
|
||||||
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
|
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
|
||||||
shard.PredictInternal(batch, dmat->Info(), out_preds, model,
|
shard.PredictInternal(batch, dmat->Info(), out_preds, batch_offset, model,
|
||||||
h_tree_segments, h_nodes, tree_begin, tree_end);
|
h_tree_segments, h_nodes, tree_begin, tree_end);
|
||||||
});
|
});
|
||||||
|
batch_offset += batch.Size() * model.param.num_output_group;
|
||||||
i_batch++;
|
i_batch++;
|
||||||
}
|
}
|
||||||
monitor_.StopCuda("DevicePredictInternal");
|
monitor_.StopCuda("DevicePredictInternal");
|
||||||
|
|||||||
@ -26,7 +26,7 @@ TEST(SparsePageDMatrix, MetaInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(SparsePageDMatrix, RowAccess) {
|
TEST(SparsePageDMatrix, RowAccess) {
|
||||||
std::unique_ptr<xgboost::DMatrix> dmat = xgboost::CreateSparsePageDMatrix();
|
std::unique_ptr<xgboost::DMatrix> dmat = xgboost::CreateSparsePageDMatrix(12, 64);
|
||||||
|
|
||||||
// Test the data read into the first row
|
// Test the data read into the first row
|
||||||
auto &batch = *dmat->GetRowBatches().begin();
|
auto &batch = *dmat->GetRowBatches().begin();
|
||||||
|
|||||||
@ -143,13 +143,13 @@ std::shared_ptr<xgboost::DMatrix>* CreateDMatrix(int rows, int columns,
|
|||||||
return static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
|
return static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<DMatrix> CreateSparsePageDMatrix() {
|
std::unique_ptr<DMatrix> CreateSparsePageDMatrix(size_t n_entries, size_t page_size) {
|
||||||
// Create sufficiently large data to make two row pages
|
// Create sufficiently large data to make two row pages
|
||||||
dmlc::TemporaryDirectory tempdir;
|
dmlc::TemporaryDirectory tempdir;
|
||||||
const std::string tmp_file = tempdir.path + "/big.libsvm";
|
const std::string tmp_file = tempdir.path + "/big.libsvm";
|
||||||
CreateBigTestData(tmp_file, 12);
|
CreateBigTestData(tmp_file, n_entries);
|
||||||
std::unique_ptr<DMatrix> dmat = std::unique_ptr<DMatrix>(DMatrix::Load(
|
std::unique_ptr<DMatrix> dmat = std::unique_ptr<DMatrix>(DMatrix::Load(
|
||||||
tmp_file + "#" + tmp_file + ".cache", true, false, "auto", 64UL));
|
tmp_file + "#" + tmp_file + ".cache", true, false, "auto", page_size));
|
||||||
EXPECT_TRUE(FileExists(tmp_file + ".cache.row.page"));
|
EXPECT_TRUE(FileExists(tmp_file + ".cache.row.page"));
|
||||||
|
|
||||||
// Loop over the batches and count the records
|
// Loop over the batches and count the records
|
||||||
@ -159,7 +159,7 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrix() {
|
|||||||
batch_count++;
|
batch_count++;
|
||||||
row_count += batch.Size();
|
row_count += batch.Size();
|
||||||
}
|
}
|
||||||
EXPECT_EQ(batch_count, 2);
|
EXPECT_GE(batch_count, 2);
|
||||||
EXPECT_EQ(row_count, dmat->Info().num_row_);
|
EXPECT_EQ(row_count, dmat->Info().num_row_);
|
||||||
|
|
||||||
return dmat;
|
return dmat;
|
||||||
|
|||||||
@ -154,7 +154,7 @@ class SimpleRealUniformDistribution {
|
|||||||
std::shared_ptr<xgboost::DMatrix> *CreateDMatrix(int rows, int columns,
|
std::shared_ptr<xgboost::DMatrix> *CreateDMatrix(int rows, int columns,
|
||||||
float sparsity, int seed = 0);
|
float sparsity, int seed = 0);
|
||||||
|
|
||||||
std::unique_ptr<DMatrix> CreateSparsePageDMatrix();
|
std::unique_ptr<DMatrix> CreateSparsePageDMatrix(size_t n_entries, size_t page_size);
|
||||||
|
|
||||||
gbm::GBTreeModel CreateTestModel();
|
gbm::GBTreeModel CreateTestModel();
|
||||||
|
|
||||||
|
|||||||
@ -55,7 +55,7 @@ TEST(cpu_predictor, Test) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(cpu_predictor, ExternalMemoryTest) {
|
TEST(cpu_predictor, ExternalMemoryTest) {
|
||||||
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix();
|
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(12, 64);
|
||||||
|
|
||||||
std::unique_ptr<Predictor> cpu_predictor =
|
std::unique_ptr<Predictor> cpu_predictor =
|
||||||
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor"));
|
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor"));
|
||||||
|
|||||||
@ -84,6 +84,46 @@ TEST(gpu_predictor, Test) {
|
|||||||
delete dmat;
|
delete dmat;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(gpu_predictor, ExternalMemoryTest) {
|
||||||
|
std::unique_ptr<Predictor> gpu_predictor =
|
||||||
|
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor"));
|
||||||
|
gpu_predictor->Init({}, {});
|
||||||
|
gbm::GBTreeModel model = CreateTestModel();
|
||||||
|
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(32, 64);
|
||||||
|
|
||||||
|
// Test predict batch
|
||||||
|
HostDeviceVector<float> out_predictions;
|
||||||
|
gpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
|
||||||
|
EXPECT_EQ(out_predictions.Size(), dmat->Info().num_row_);
|
||||||
|
for (const auto& v : out_predictions.HostVector()) {
|
||||||
|
ASSERT_EQ(v, 1.5);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test predict leaf
|
||||||
|
std::vector<float> leaf_out_predictions;
|
||||||
|
gpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model);
|
||||||
|
EXPECT_EQ(leaf_out_predictions.size(), dmat->Info().num_row_);
|
||||||
|
for (const auto& v : leaf_out_predictions) {
|
||||||
|
ASSERT_EQ(v, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test predict contribution
|
||||||
|
std::vector<float> out_contribution;
|
||||||
|
gpu_predictor->PredictContribution(dmat.get(), &out_contribution, model);
|
||||||
|
EXPECT_EQ(out_contribution.size(), dmat->Info().num_row_);
|
||||||
|
for (const auto& v : out_contribution) {
|
||||||
|
ASSERT_EQ(v, 1.5);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test predict contribution (approximate method)
|
||||||
|
std::vector<float> out_contribution_approximate;
|
||||||
|
gpu_predictor->PredictContribution(dmat.get(), &out_contribution_approximate, model, true);
|
||||||
|
EXPECT_EQ(out_contribution_approximate.size(), dmat->Info().num_row_);
|
||||||
|
for (const auto& v : out_contribution_approximate) {
|
||||||
|
ASSERT_EQ(v, 1.5);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#if defined(XGBOOST_USE_NCCL)
|
#if defined(XGBOOST_USE_NCCL)
|
||||||
// Test whether pickling preserves predictor parameters
|
// Test whether pickling preserves predictor parameters
|
||||||
TEST(gpu_predictor, MGPU_PicklingTest) {
|
TEST(gpu_predictor, MGPU_PicklingTest) {
|
||||||
@ -195,4 +235,4 @@ TEST(gpu_predictor, MGPU_Test) {
|
|||||||
}
|
}
|
||||||
#endif // defined(XGBOOST_USE_NCCL)
|
#endif // defined(XGBOOST_USE_NCCL)
|
||||||
} // namespace predictor
|
} // namespace predictor
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user