diff --git a/include/xgboost/data.h b/include/xgboost/data.h index baf600092..28ff1ba2b 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -284,6 +284,7 @@ class BatchIteratorImpl { public: virtual ~BatchIteratorImpl() {} virtual BatchIteratorImpl* Clone() = 0; + virtual SparsePage& operator*() = 0; virtual const SparsePage& operator*() const = 0; virtual void operator++() = 0; virtual bool AtEnd() const = 0; @@ -307,6 +308,11 @@ class BatchIterator { ++(*impl_); } + SparsePage& operator*() { + CHECK(impl_ != nullptr); + return *(*impl_); + } + const SparsePage& operator*() const { CHECK(impl_ != nullptr); return *(*impl_); diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index c956932a9..3126cd039 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -32,6 +32,10 @@ float SimpleDMatrix::GetColDensity(size_t cidx) { class SimpleBatchIteratorImpl : public BatchIteratorImpl { public: explicit SimpleBatchIteratorImpl(SparsePage* page) : page_(page) {} + SparsePage& operator*() override { + CHECK(page_ != nullptr); + return *page_; + } const SparsePage& operator*() const override { CHECK(page_ != nullptr); return *page_; diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index b69d6eb52..9aad6a581 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -29,6 +29,7 @@ class SparseBatchIteratorImpl : public BatchIteratorImpl { explicit SparseBatchIteratorImpl(SparsePageSource* source) : source_(source) { CHECK(source_ != nullptr); } + SparsePage& operator*() override { return source_->Value(); } const SparsePage& operator*() const override { return source_->Value(); } void operator++() override { at_end_ = !source_->Next(); } bool AtEnd() const override { return at_end_; } diff --git a/src/data/sparse_page_source.cc b/src/data/sparse_page_source.cc index ee7b13a4f..34007fc73 100644 --- a/src/data/sparse_page_source.cc +++ b/src/data/sparse_page_source.cc @@ -104,6 +104,10 @@ void SparsePageSource::BeforeFirst() { } } +SparsePage& SparsePageSource::Value() { + return *page_; +} + const SparsePage& SparsePageSource::Value() const { return *page_; } diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index fbec44498..8a742d32d 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -43,6 +43,7 @@ class SparsePageSource : public DataSource { // implement BeforeFirst void BeforeFirst() override; // implement Value + SparsePage& Value(); const SparsePage& Value() const override; /*! * \brief Create source by taking data from parser. diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 0fcd0270e..821014a7e 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -246,6 +246,7 @@ class GPUPredictor : public xgboost::Predictor { void PredictInternal (const SparsePage& batch, const MetaInfo& info, HostDeviceVector* predictions, + const size_t batch_offset, const gbm::GBTreeModel& model, const thrust::host_vector& h_tree_segments, const thrust::host_vector& h_nodes, @@ -284,8 +285,8 @@ class GPUPredictor : public xgboost::Predictor { data_distr.Devices().Index(device_)); PredictKernel<<>> - (dh::ToSpan(nodes_), predictions->DeviceSpan(device_), dh::ToSpan(tree_segments_), - dh::ToSpan(tree_group_), batch.offset.DeviceSpan(device_), + (dh::ToSpan(nodes_), predictions->DeviceSpan(device_).subspan(batch_offset), + dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(device_), batch.data.DeviceSpan(device_), tree_begin, tree_end, info.num_col_, num_rows, entry_start, use_shared, model.param.num_output_group); } @@ -324,18 +325,19 @@ class GPUPredictor : public xgboost::Predictor { } size_t i_batch = 0; - - for (const auto &batch : dmat->GetRowBatches()) { - CHECK_EQ(i_batch, 0) << "External memory not supported"; + size_t batch_offset = 0; + for (auto &batch : dmat->GetRowBatches()) { + CHECK(i_batch == 0 || devices_.Size() == 1) << "External memory not supported for multi-GPU"; // out_preds have been sharded and resized in InitOutPredictions() batch.offset.Shard(GPUDistribution::Overlap(devices_, 1)); std::vector device_offsets; DeviceOffsets(batch.offset, &device_offsets); - batch.data.Shard(GPUDistribution::Explicit(devices_, device_offsets)); + batch.data.Reshard(GPUDistribution::Explicit(devices_, device_offsets)); dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) { - shard.PredictInternal(batch, dmat->Info(), out_preds, model, + shard.PredictInternal(batch, dmat->Info(), out_preds, batch_offset, model, h_tree_segments, h_nodes, tree_begin, tree_end); }); + batch_offset += batch.Size() * model.param.num_output_group; i_batch++; } monitor_.StopCuda("DevicePredictInternal"); diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cc b/tests/cpp/data/test_sparse_page_dmatrix.cc index 2ea2475c2..dd842a44f 100644 --- a/tests/cpp/data/test_sparse_page_dmatrix.cc +++ b/tests/cpp/data/test_sparse_page_dmatrix.cc @@ -26,7 +26,7 @@ TEST(SparsePageDMatrix, MetaInfo) { } TEST(SparsePageDMatrix, RowAccess) { - std::unique_ptr dmat = xgboost::CreateSparsePageDMatrix(); + std::unique_ptr dmat = xgboost::CreateSparsePageDMatrix(12, 64); // Test the data read into the first row auto &batch = *dmat->GetRowBatches().begin(); diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index abe0ce776..56d31fce3 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -143,13 +143,13 @@ std::shared_ptr* CreateDMatrix(int rows, int columns, return static_cast *>(handle); } -std::unique_ptr CreateSparsePageDMatrix() { +std::unique_ptr CreateSparsePageDMatrix(size_t n_entries, size_t page_size) { // Create sufficiently large data to make two row pages dmlc::TemporaryDirectory tempdir; const std::string tmp_file = tempdir.path + "/big.libsvm"; - CreateBigTestData(tmp_file, 12); + CreateBigTestData(tmp_file, n_entries); std::unique_ptr dmat = std::unique_ptr(DMatrix::Load( - tmp_file + "#" + tmp_file + ".cache", true, false, "auto", 64UL)); + tmp_file + "#" + tmp_file + ".cache", true, false, "auto", page_size)); EXPECT_TRUE(FileExists(tmp_file + ".cache.row.page")); // Loop over the batches and count the records @@ -159,7 +159,7 @@ std::unique_ptr CreateSparsePageDMatrix() { batch_count++; row_count += batch.Size(); } - EXPECT_EQ(batch_count, 2); + EXPECT_GE(batch_count, 2); EXPECT_EQ(row_count, dmat->Info().num_row_); return dmat; diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index 33dd072f2..c75b5dab9 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -154,7 +154,7 @@ class SimpleRealUniformDistribution { std::shared_ptr *CreateDMatrix(int rows, int columns, float sparsity, int seed = 0); -std::unique_ptr CreateSparsePageDMatrix(); +std::unique_ptr CreateSparsePageDMatrix(size_t n_entries, size_t page_size); gbm::GBTreeModel CreateTestModel(); diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index 543ba3d42..449e6662d 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -55,7 +55,7 @@ TEST(cpu_predictor, Test) { } TEST(cpu_predictor, ExternalMemoryTest) { - std::unique_ptr dmat = CreateSparsePageDMatrix(); + std::unique_ptr dmat = CreateSparsePageDMatrix(12, 64); std::unique_ptr cpu_predictor = std::unique_ptr(Predictor::Create("cpu_predictor")); diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index 091e5b2c9..f992b8509 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -84,6 +84,46 @@ TEST(gpu_predictor, Test) { delete dmat; } +TEST(gpu_predictor, ExternalMemoryTest) { + std::unique_ptr gpu_predictor = + std::unique_ptr(Predictor::Create("gpu_predictor")); + gpu_predictor->Init({}, {}); + gbm::GBTreeModel model = CreateTestModel(); + std::unique_ptr dmat = CreateSparsePageDMatrix(32, 64); + + // Test predict batch + HostDeviceVector out_predictions; + gpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); + EXPECT_EQ(out_predictions.Size(), dmat->Info().num_row_); + for (const auto& v : out_predictions.HostVector()) { + ASSERT_EQ(v, 1.5); + } + + // Test predict leaf + std::vector leaf_out_predictions; + gpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model); + EXPECT_EQ(leaf_out_predictions.size(), dmat->Info().num_row_); + for (const auto& v : leaf_out_predictions) { + ASSERT_EQ(v, 0); + } + + // Test predict contribution + std::vector out_contribution; + gpu_predictor->PredictContribution(dmat.get(), &out_contribution, model); + EXPECT_EQ(out_contribution.size(), dmat->Info().num_row_); + for (const auto& v : out_contribution) { + ASSERT_EQ(v, 1.5); + } + + // Test predict contribution (approximate method) + std::vector out_contribution_approximate; + gpu_predictor->PredictContribution(dmat.get(), &out_contribution_approximate, model, true); + EXPECT_EQ(out_contribution_approximate.size(), dmat->Info().num_row_); + for (const auto& v : out_contribution_approximate) { + ASSERT_EQ(v, 1.5); + } +} + #if defined(XGBOOST_USE_NCCL) // Test whether pickling preserves predictor parameters TEST(gpu_predictor, MGPU_PicklingTest) { @@ -195,4 +235,4 @@ TEST(gpu_predictor, MGPU_Test) { } #endif // defined(XGBOOST_USE_NCCL) } // namespace predictor -} // namespace xgboost \ No newline at end of file +} // namespace xgboost