[EM] Enable access to the number of batches. (#10691)

- Expose `NumBatches` in `DMatrix`.
- Small cleanup for removing the legacy CUDA stream and ~force CUDA context initialization~.
- Purge old external memory data generation code.
parent 033a666900
commit 8d7fe262d9
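
The interface change in a nutshell: batch awareness moves into the `DMatrix` base class. A minimal standalone sketch of the resulting shape (simplified from the hunks below; the real classes carry many more members):

```cpp
#include <cstdint>

class DMatrix {
 public:
  // New hook; in-memory implementations keep the default of one batch.
  [[nodiscard]] virtual std::int32_t NumBatches() const { return 1; }
  // The legacy query is now derived from the batch count instead of being
  // overridden by every subclass.
  [[nodiscard]] bool SingleColBlock() const { return this->NumBatches() == 1; }
  virtual ~DMatrix() = default;
};

// External-memory implementations report their actual page count.
class SparsePageDMatrix : public DMatrix {
 public:
  explicit SparsePageDMatrix(std::int32_t n_batches) : n_batches_{n_batches} {}
  [[nodiscard]] std::int32_t NumBatches() const override { return n_batches_; }

 private:
  std::int32_t n_batches_{0};
};
```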
@@ -541,9 +541,12 @@ class DMatrix {
   [[nodiscard]] bool PageExists() const;
 
   /**
-   * @return Whether the data columns single column block.
+   * @return Whether the data contains a single batch.
+   *
+   * The naming is legacy.
    */
-  [[nodiscard]] virtual bool SingleColBlock() const = 0;
+  [[nodiscard]] bool SingleColBlock() const { return this->NumBatches() == 1; }
+  [[nodiscard]] virtual std::int32_t NumBatches() const { return 1; }
 
   virtual ~DMatrix();
 
@@ -486,24 +486,20 @@ class TypedDiscard : public thrust::discard_iterator<T> {
 } // namespace detail
 
 template <typename T>
-using TypedDiscard =
-    std::conditional_t<HasThrustMinorVer<12>(), detail::TypedDiscardCTK114<T>,
+using TypedDiscard = std::conditional_t<HasThrustMinorVer<12>(), detail::TypedDiscardCTK114<T>,
                        detail::TypedDiscard<T>>;
 
 template <typename VectorT, typename T = typename VectorT::value_type,
           typename IndexT = typename xgboost::common::Span<T>::index_type>
-xgboost::common::Span<T> ToSpan(
-    VectorT &vec,
-    IndexT offset = 0,
+xgboost::common::Span<T> ToSpan(VectorT &vec, IndexT offset = 0,
                                 IndexT size = std::numeric_limits<size_t>::max()) {
   size = size == std::numeric_limits<size_t>::max() ? vec.size() : size;
   CHECK_LE(offset + size, vec.size());
-  return {vec.data().get() + offset, size};
+  return {thrust::raw_pointer_cast(vec.data()) + offset, size};
 }
 
 template <typename T>
-xgboost::common::Span<T> ToSpan(thrust::device_vector<T>& vec,
-                                size_t offset, size_t size) {
+xgboost::common::Span<T> ToSpan(thrust::device_vector<T> &vec, size_t offset, size_t size) {
   return ToSpan(vec, offset, size);
 }
 
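A note on the `ToSpan` change: `thrust::raw_pointer_cast` unwraps both plain pointers and fancy pointers such as `thrust::device_ptr`, whereas `.get()` compiles only for the latter. A small standalone illustration (not taken from the patch):

```cpp
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

// raw_pointer_cast passes a plain pointer through unchanged and calls
// .get() on fancy pointers, so one template covers both container kinds.
template <typename VectorT>
auto* RawData(VectorT& vec) {
  return thrust::raw_pointer_cast(vec.data());
}

int main() {
  thrust::device_vector<float> d(16);
  thrust::host_vector<float> h(16);
  float* dp = RawData(d);  // unwraps thrust::device_ptr<float>
  float* hp = RawData(h);  // already a float*, returned as-is
  (void)dp;
  (void)hp;
  return 0;
}
```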
@@ -874,13 +870,7 @@ inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT
 
 // Changing this has effect on prediction return, where we need to pass the pointer to
 // third-party libraries like cuPy
-inline CUDAStreamView DefaultStream() {
-#ifdef CUDA_API_PER_THREAD_DEFAULT_STREAM
-  return CUDAStreamView{cudaStreamPerThread};
-#else
-  return CUDAStreamView{cudaStreamLegacy};
-#endif
-}
+inline CUDAStreamView DefaultStream() { return CUDAStreamView{cudaStreamPerThread}; }
 
 class CUDAStream {
   cudaStream_t stream_;
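Background on the stream cleanup: `cudaStreamPerThread` is a reserved stream handle that resolves to a separate default stream for each host thread, so work launched from different threads no longer serializes on the device-wide legacy stream; the `#ifdef` goes away because the handle itself is always defined by the runtime headers. A hedged standalone sketch of the semantics:

```cpp
#include <cuda_runtime.h>

__global__ void Kernel() {}

// Kernels submitted on cudaStreamPerThread from different host threads can
// overlap; on the legacy default stream they would serialize device-wide.
void LaunchOnPerThreadStream() {
  Kernel<<<1, 1, 0, cudaStreamPerThread>>>();  // reserved handle, no cudaStreamCreate
  cudaStreamSynchronize(cudaStreamPerThread);  // waits only on this thread's stream
}
```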
@@ -74,6 +74,8 @@ void ExtMemQuantileDMatrix::InitFromCPU(
   cpu_impl::GetDataShape(ctx, proxy, *iter, missing, &ext_info);
   ext_info.SetInfo(ctx, &this->info_);
 
+  this->n_batches_ = ext_info.n_batches;
+
   /**
    * Generate quantiles
    */
@@ -33,7 +33,7 @@ class ExtMemQuantileDMatrix : public QuantileDMatrix {
                         std::string cache, bst_bin_t max_bin, bool on_host);
   ~ExtMemQuantileDMatrix() override;
 
-  [[nodiscard]] bool SingleColBlock() const override { return false; }
+  [[nodiscard]] std::int32_t NumBatches() const override { return n_batches_; }
 
  private:
  void InitFromCPU(
@@ -63,6 +63,7 @@ class ExtMemQuantileDMatrix : public QuantileDMatrix {
   std::string cache_prefix_;
   bool on_host_;
   BatchParam batch_;
+  bst_idx_t n_batches_{0};
 
   using EllpackDiskPtr = std::shared_ptr<ExtEllpackPageSource>;
   using EllpackHostPtr = std::shared_ptr<ExtEllpackPageHostSource>;
@@ -57,8 +57,6 @@ class IterativeDMatrix : public QuantileDMatrix {
 
   BatchSet<EllpackPage> GetEllpackBatches(Context const *ctx, const BatchParam &param) override;
   BatchSet<ExtSparsePage> GetExtBatches(Context const *ctx, BatchParam const &param) override;
-
-  bool SingleColBlock() const override { return true; }
 };
 } // namespace data
 } // namespace xgboost
@@ -94,7 +94,6 @@ class DMatrixProxy : public DMatrix {
   MetaInfo const& Info() const override { return info_; }
   Context const* Ctx() const override { return &ctx_; }
 
-  bool SingleColBlock() const override { return false; }
   bool EllpackExists() const override { return false; }
   bool GHistIndexExists() const override { return false; }
   bool SparsePageExists() const override { return false; }
@@ -33,7 +33,6 @@ class SimpleDMatrix : public DMatrix {
   const MetaInfo& Info() const override;
   Context const* Ctx() const override { return &fmat_ctx_; }
 
-  bool SingleColBlock() const override { return true; }
   DMatrix* Slice(common::Span<int32_t const> ridxs) override;
   DMatrix* SliceCol(int num_slices, int slice_id) override;
 
@@ -90,8 +90,7 @@ class SparsePageDMatrix : public DMatrix {
   [[nodiscard]] MetaInfo &Info() override;
   [[nodiscard]] const MetaInfo &Info() const override;
   [[nodiscard]] Context const *Ctx() const override { return &fmat_ctx_; }
-  // The only DMatrix implementation that returns false.
-  [[nodiscard]] bool SingleColBlock() const override { return false; }
+  [[nodiscard]] std::int32_t NumBatches() const override { return n_batches_; }
   DMatrix *Slice(common::Span<std::int32_t const>) override {
     LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
     return nullptr;
@@ -3,10 +3,10 @@
  */
 #include "sparse_page_source.h"
 
-#include <filesystem>  // for exists
-#include <string>      // for string
 #include <cstdio>      // for remove
+#include <filesystem>  // for exists
 #include <numeric>     // for partial_sum
+#include <string>      // for string
 
 namespace xgboost::data {
 void Cache::Commit() {
@@ -27,4 +27,8 @@ void TryDeleteCacheFile(const std::string& file) {
                << "; you may want to remove it manually";
   }
 }
 
+#if !defined(XGBOOST_USE_CUDA)
+void InitNewThread::operator()() const { *GlobalConfigThreadLocalStore::Get() = config; }
+#endif
+
 } // namespace xgboost::data
@@ -18,4 +18,14 @@ void DevicePush(DMatrixProxy *proxy, float missing, SparsePage *page) {
   cuda_impl::Dispatch(proxy,
                       [&](auto const &value) { CopyToSparsePage(value, device, missing, page); });
 }
 
+void InitNewThread::operator()() const {
+  *GlobalConfigThreadLocalStore::Get() = config;
+  // For CUDA 12.2, we need to force initialize the CUDA context by synchronizing the
+  // stream when creating a new thread in the thread pool. While for CUDA 11.8, this
+  // action might cause an insufficient driver version error for some reason. Lastly, it
+  // should work with CUDA 12.5 without any action being taken.
+
+  // dh::DefaultStream().Sync();
+}
 } // namespace xgboost::data
@@ -210,6 +210,12 @@ class DefaultFormatPolicy {
   }
 };
 
+struct InitNewThread {
+  GlobalConfiguration config = *GlobalConfigThreadLocalStore::Get();
+
+  void operator()() const;
+};
+
 /**
  * @brief Base class for all page sources. Handles fetching, writing, and iteration.
  *
@@ -330,10 +336,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
  public:
  SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, bst_idx_t n_batches,
                       std::shared_ptr<Cache> cache)
-      : workers_{std::max(2, std::min(nthreads, 16)),
-                 [config = *GlobalConfigThreadLocalStore::Get()] {
-                   *GlobalConfigThreadLocalStore::Get() = config;
-                 }},
+      : workers_{std::max(2, std::min(nthreads, 16)), InitNewThread{}},
        missing_{missing},
        nthreads_{nthreads},
        n_features_{n_features},
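Why a named functor instead of the old lambda: `InitNewThread` copies the thread-local global configuration captured on the constructing thread into each pool worker, and a named type lets the `.cc` and `.cu` translation units provide different `operator()` bodies (the CUDA build carries the commented-out context-initialization hook above). A hypothetical, self-contained sketch of the pattern, not xgboost's actual thread pool:

```cpp
#include <thread>
#include <vector>

struct Config { int verbosity{0}; };  // stand-in for GlobalConfiguration
thread_local Config g_config;         // stand-in for GlobalConfigThreadLocalStore

struct InitNewThread {
  Config config = g_config;  // captured on the thread constructing the pool
  void operator()() const { g_config = config; }
};

class ThreadPool {
 public:
  ThreadPool(int n, InitNewThread init) {
    for (int i = 0; i < n; ++i) {
      workers_.emplace_back([init] {
        init();  // seed this worker's thread-local config before any work runs
        // ... worker loop would go here ...
      });
    }
  }
  ~ThreadPool() {
    for (auto& t : workers_) { t.join(); }
  }

 private:
  std::vector<std::thread> workers_;
};
```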
@@ -63,26 +63,27 @@ TEST(SparsePage, PushCSC) {
 }
 
 TEST(SparsePage, PushCSCAfterTranspose) {
-  size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
-  size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
-  std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries);
+  bst_idx_t constexpr kRows = 1024, kCols = 21;
+  auto dmat =
+      RandomDataGenerator{kRows, kCols, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true);
   const int ncols = dmat->Info().num_col_;
   SparsePage page; // Consolidated sparse page
-  for (const auto &batch : dmat->GetBatches<xgboost::SparsePage>()) {
+  for (const auto& batch : dmat->GetBatches<xgboost::SparsePage>()) {
     // Transpose each batch and push
     SparsePage tmp = batch.GetTranspose(ncols, AllThreadsForTest());
     page.PushCSC(tmp);
   }
 
   // Make sure that the final sparse page has the right number of entries
-  ASSERT_EQ(kEntries, page.data.Size());
+  ASSERT_EQ(kRows * kCols, page.data.Size());
 
   page.SortRows(AllThreadsForTest());
   auto v = page.GetView();
   for (size_t i = 0; i < v.Size(); ++i) {
     auto column = v[i];
     for (size_t j = 1; j < column.size(); ++j) {
-      ASSERT_GE(column[j].fvalue, column[j-1].fvalue);
+      ASSERT_GE(column[j].fvalue, column[j - 1].fvalue);
     }
   }
 }
@@ -140,13 +140,11 @@ struct ReadRowFunction {
 TEST(EllpackPage, Copy) {
   constexpr size_t kRows = 1024;
   constexpr size_t kCols = 16;
-  constexpr size_t kPageSize = 1024;
 
   // Create a DMatrix with multiple batches.
-  dmlc::TemporaryDirectory tmpdir;
-  std::unique_ptr<DMatrix>
-      dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
-  Context ctx{MakeCUDACtx(0)};
+  auto dmat =
+      RandomDataGenerator{kRows, kCols, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true);
+  auto ctx = MakeCUDACtx(0);
   auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
   auto page = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
 
@@ -187,14 +185,12 @@ TEST(EllpackPage, Copy) {
 TEST(EllpackPage, Compact) {
   constexpr size_t kRows = 16;
   constexpr size_t kCols = 2;
-  constexpr size_t kPageSize = 1;
   constexpr size_t kCompactedRows = 8;
 
   // Create a DMatrix with multiple batches.
-  dmlc::TemporaryDirectory tmpdir;
-  std::unique_ptr<DMatrix> dmat(
-      CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
-  Context ctx{MakeCUDACtx(0)};
+  auto dmat =
+      RandomDataGenerator{kRows, kCols, 0.0f}.Batches(2).GenerateSparsePageDMatrix("temp", true);
+  auto ctx = MakeCUDACtx(0);
   auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
   auto page = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
 
@@ -214,15 +214,15 @@ TEST(SparsePageDMatrix, MetaInfo) {
 }
 
 TEST(SparsePageDMatrix, RowAccess) {
-  std::unique_ptr<xgboost::DMatrix> dmat = xgboost::CreateSparsePageDMatrix(24);
+  auto dmat = RandomDataGenerator{12, 6, 0.8f}.Batches(2).GenerateSparsePageDMatrix("temp", false);
 
   // Test the data read into the first row
   auto &batch = *dmat->GetBatches<xgboost::SparsePage>().begin();
   auto page = batch.GetView();
   auto first_row = page[0];
-  ASSERT_EQ(first_row.size(), 3ul);
-  EXPECT_EQ(first_row[2].index, 2u);
-  EXPECT_NEAR(first_row[2].fvalue, 0.986566, 1e-4);
+  ASSERT_EQ(first_row.size(), 1ul);
+  EXPECT_EQ(first_row[0].index, 5u);
+  EXPECT_NEAR(first_row[0].fvalue, 0.1805125, 1e-4);
 }
 
 TEST(SparsePageDMatrix, ColAccess) {
@@ -268,11 +268,10 @@ TEST(SparsePageDMatrix, ColAccess) {
 }
 
 TEST(SparsePageDMatrix, ThreadSafetyException) {
-  size_t constexpr kEntriesPerCol = 3;
-  size_t constexpr kEntries = 64 * kEntriesPerCol * 2;
   Context ctx;
 
-  std::unique_ptr<xgboost::DMatrix> dmat = xgboost::CreateSparsePageDMatrix(kEntries);
+  auto dmat =
+      RandomDataGenerator{4096, 12, 0.0f}.Batches(8).GenerateSparsePageDMatrix("temp", true);
 
   int threads = 1000;
 
@@ -304,10 +303,9 @@ TEST(SparsePageDMatrix, ThreadSafetyException) {
 
 // Multi-batches access
 TEST(SparsePageDMatrix, ColAccessBatches) {
-  size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
-  size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
   // Create multiple sparse pages
-  std::unique_ptr<xgboost::DMatrix> dmat{xgboost::CreateSparsePageDMatrix(kEntries)};
+  auto dmat =
+      RandomDataGenerator{1024, 32, 0.4f}.Batches(3).GenerateSparsePageDMatrix("temp", true);
   ASSERT_EQ(dmat->Ctx()->Threads(), AllThreadsForTest());
   Context ctx;
   for (auto const &page : dmat->GetBatches<xgboost::CSCPage>(&ctx)) {
@@ -115,13 +115,10 @@ TEST(SparsePageDMatrix, EllpackSkipSparsePage) {
 }
 
 TEST(SparsePageDMatrix, MultipleEllpackPages) {
-  Context ctx{MakeCUDACtx(0)};
+  auto ctx = MakeCUDACtx(0);
   auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
   dmlc::TemporaryDirectory tmpdir;
-  std::string filename = tmpdir.path + "/big.libsvm";
-  size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
-  size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
-  std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries, filename);
+  auto dmat = RandomDataGenerator{1024, 2, 0.5f}.Batches(2).GenerateSparsePageDMatrix("temp", true);
 
   // Loop over the batches and count the records
   std::int64_t batch_count = 0;
@@ -135,15 +132,13 @@ TEST(SparsePageDMatrix, MultipleEllpackPages) {
   EXPECT_EQ(row_count, dmat->Info().num_row_);
 
   auto path =
-      data::MakeId(filename,
-                   dynamic_cast<data::SparsePageDMatrix *>(dmat.get())) +
-      ".ellpack.page";
+      data::MakeId("tmep", dynamic_cast<data::SparsePageDMatrix*>(dmat.get())) + ".ellpack.page";
 }
 
 TEST(SparsePageDMatrix, RetainEllpackPage) {
-  Context ctx{MakeCUDACtx(0)};
+  auto ctx = MakeCUDACtx(0);
   auto param = BatchParam{32, tree::TrainParam::DftSparseThreshold()};
-  auto m = CreateSparsePageDMatrix(10000);
+  auto m = RandomDataGenerator{2048, 4, 0.0f}.Batches(8).GenerateSparsePageDMatrix("temp", true);
 
   auto batches = m->GetBatches<EllpackPage>(&ctx, param);
   auto begin = batches.begin();
@@ -278,20 +273,19 @@ struct ReadRowFunction {
 };
 
 TEST(SparsePageDMatrix, MultipleEllpackPageContent) {
-  constexpr size_t kRows = 6;
+  constexpr size_t kRows = 16;
   constexpr size_t kCols = 2;
   constexpr int kMaxBins = 256;
-  constexpr size_t kPageSize = 1;
 
   // Create an in-memory DMatrix.
-  std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
+  auto dmat =
+      RandomDataGenerator{kRows, kCols, 0.0f}.Batches(1).GenerateSparsePageDMatrix("temp", true);
 
   // Create a DMatrix with multiple batches.
-  dmlc::TemporaryDirectory tmpdir;
-  std::unique_ptr<DMatrix>
-      dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
+  auto dmat_ext =
+      RandomDataGenerator{kRows, kCols, 0.0f}.Batches(2).GenerateSparsePageDMatrix("temp", true);
 
-  Context ctx{MakeCUDACtx(0)};
+  auto ctx = MakeCUDACtx(0);
   auto param = BatchParam{kMaxBins, tree::TrainParam::DftSparseThreshold()};
   auto impl = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
   EXPECT_EQ(impl->base_rowid, 0);
@@ -325,17 +319,16 @@ TEST(SparsePageDMatrix, EllpackPageMultipleLoops) {
   constexpr size_t kRows = 1024;
   constexpr size_t kCols = 16;
   constexpr int kMaxBins = 256;
-  constexpr size_t kPageSize = 4096;
 
   // Create an in-memory DMatrix.
-  std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
+  auto dmat =
+      RandomDataGenerator{kRows, kCols, 0.0f}.Batches(1).GenerateSparsePageDMatrix("temp", true);
 
   // Create a DMatrix with multiple batches.
-  dmlc::TemporaryDirectory tmpdir;
-  std::unique_ptr<DMatrix>
-      dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
+  auto dmat_ext =
+      RandomDataGenerator{kRows, kCols, 0.0f}.Batches(8).GenerateSparsePageDMatrix("temp", true);
 
-  Context ctx{MakeCUDACtx(0)};
+  auto ctx = MakeCUDACtx(0);
   auto param = BatchParam{kMaxBins, tree::TrainParam::DftSparseThreshold()};
 
   size_t current_row = 0;
@@ -715,7 +715,7 @@ TEST(GBTree, InplacePredictionError) {
       p_fmat = rng.GenerateQuantileDMatrix(true);
     } else {
 #if defined(XGBOOST_USE_CUDA)
-      p_fmat = rng.GenerateDeviceDMatrix(true);
+      p_fmat = rng.Device(ctx->Device()).GenerateQuantileDMatrix(true);
 #else
       CHECK(p_fmat);
 #endif  // defined(XGBOOST_USE_CUDA)
@@ -13,7 +13,6 @@
 
 #include <algorithm>
 #include <limits>  // for numeric_limits
-#include <random>
 
 #include "../../src/collective/communicator-inl.h"  // for GetRank
 #include "../../src/data/adapter.h"
@@ -21,8 +20,6 @@
 #include "../../src/data/simple_dmatrix.h"
 #include "../../src/data/sparse_page_dmatrix.h"
 #include "../../src/gbm/gbtree_model.h"
-#include "../../src/tree/param.h"  // for TrainParam
-#include "filesystem.h"            // dmlc::TemporaryDirectory
 #include "xgboost/c_api.h"
 #include "xgboost/predictor.h"
 
@@ -456,6 +453,7 @@ void RandomDataGenerator::GenerateCSR(
   }
 
   EXPECT_EQ(batch_count, n_batches_);
+  EXPECT_EQ(dmat->NumBatches(), n_batches_);
   EXPECT_EQ(row_count, dmat->Info().num_row_);
 
   if (with_label) {
@@ -503,13 +501,24 @@ void RandomDataGenerator::GenerateCSR(
 }
 
 std::shared_ptr<DMatrix> RandomDataGenerator::GenerateQuantileDMatrix(bool with_label) {
-  NumpyArrayIterForTest iter{this->sparsity_, this->rows_, this->cols_, 1};
-  auto m = std::make_shared<data::IterativeDMatrix>(
-      &iter, iter.Proxy(), nullptr, Reset, Next, std::numeric_limits<float>::quiet_NaN(), 0, bins_);
-  if (with_label) {
-    this->GenerateLabels(m);
+  std::shared_ptr<data::IterativeDMatrix> p_fmat;
+
+  if (this->device_.IsCPU()) {
+    NumpyArrayIterForTest iter{this->sparsity_, this->rows_, this->cols_, 1};
+    p_fmat =
+        std::make_shared<data::IterativeDMatrix>(&iter, iter.Proxy(), nullptr, Reset, Next,
+                                                 std::numeric_limits<float>::quiet_NaN(), 0, bins_);
+  } else {
+    CudaArrayIterForTest iter{this->sparsity_, this->rows_, this->cols_, 1};
+    p_fmat =
+        std::make_shared<data::IterativeDMatrix>(&iter, iter.Proxy(), nullptr, Reset, Next,
+                                                 std::numeric_limits<float>::quiet_NaN(), 0, bins_);
   }
-  return m;
+
+  if (with_label) {
+    this->GenerateLabels(p_fmat);
+  }
+  return p_fmat;
 }
 
 #if !defined(XGBOOST_USE_CUDA)
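With `GenerateDeviceDMatrix` removed, tests pick the backend through the generator's device setting alone; a typical pair of calls, mirroring the test hunks (the row/column counts here are arbitrary):

```cpp
// CPU-backed QuantileDMatrix (default device):
auto p_cpu = RandomDataGenerator{1024, 16, 0.0f}.GenerateQuantileDMatrix(true);
// GPU-backed QuantileDMatrix, selected purely by Device():
auto p_gpu =
    RandomDataGenerator{1024, 16, 0.0f}.Device(DeviceOrd::CUDA(0)).GenerateQuantileDMatrix(true);
```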
@@ -551,125 +560,6 @@ std::shared_ptr<DMatrix> GetDMatrixFromData(const std::vector<float>& x, std::si
   return p_fmat;
 }
 
-std::unique_ptr<DMatrix> CreateSparsePageDMatrix(bst_idx_t n_samples, bst_feature_t n_features,
-                                                 size_t n_batches, std::string prefix) {
-  CHECK_GE(n_samples, n_batches);
-  NumpyArrayIterForTest iter(0, n_samples, n_features, n_batches);
-
-  std::unique_ptr<DMatrix> dmat{DMatrix::Create(
-      static_cast<DataIterHandle>(&iter), iter.Proxy(), Reset, Next,
-      std::numeric_limits<float>::quiet_NaN(), omp_get_max_threads(), prefix, false)};
-
-  auto row_page_path =
-      data::MakeId(prefix, dynamic_cast<data::SparsePageDMatrix*>(dmat.get())) + ".row.page";
-  EXPECT_TRUE(FileExists(row_page_path)) << row_page_path;
-
-  // Loop over the batches and count the number of pages
-  int64_t batch_count = 0;
-  int64_t row_count = 0;
-  for (const auto& batch : dmat->GetBatches<xgboost::SparsePage>()) {
-    batch_count++;
-    row_count += batch.Size();
-  }
-
-  EXPECT_GE(batch_count, n_batches);
-  EXPECT_EQ(row_count, dmat->Info().num_row_);
-  return dmat;
-}
-
-std::unique_ptr<DMatrix> CreateSparsePageDMatrix(size_t n_entries,
-                                                 std::string prefix) {
-  size_t n_columns = 3;
-  size_t n_rows = n_entries / n_columns;
-  NumpyArrayIterForTest iter(0, n_rows, n_columns, 2);
-
-  std::unique_ptr<DMatrix> dmat{
-      DMatrix::Create(static_cast<DataIterHandle>(&iter), iter.Proxy(), Reset, Next,
-                      std::numeric_limits<float>::quiet_NaN(), 0, prefix, false)};
-  auto row_page_path =
-      data::MakeId(prefix,
-                   dynamic_cast<data::SparsePageDMatrix *>(dmat.get())) +
-      ".row.page";
-  EXPECT_TRUE(FileExists(row_page_path)) << row_page_path;
-
-  // Loop over the batches and count the records
-  int64_t batch_count = 0;
-  int64_t row_count = 0;
-  for (const auto &batch : dmat->GetBatches<xgboost::SparsePage>()) {
-    batch_count++;
-    row_count += batch.Size();
-  }
-  EXPECT_GE(batch_count, 2);
-  EXPECT_EQ(row_count, dmat->Info().num_row_);
-  return dmat;
-}
-
-std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(size_t n_rows, size_t n_cols,
-                                                       size_t page_size, bool deterministic,
-                                                       const dmlc::TemporaryDirectory& tempdir) {
-  if (!n_rows || !n_cols) {
-    return nullptr;
-  }
-
-  // Create the svm file in a temp dir
-  const std::string tmp_file = tempdir.path + "/big.libsvm";
-
-  std::ofstream fo(tmp_file.c_str());
-  size_t cols_per_row = ((std::max(n_rows, n_cols) - 1) / std::min(n_rows, n_cols)) + 1;
-  int64_t rem_cols = n_cols;
-  size_t col_idx = 0;
-
-  // Random feature id generator
-  std::random_device rdev;
-  std::unique_ptr<std::mt19937> gen;
-  if (deterministic) {
-    // Seed it with a constant value for this configuration - without getting too fancy
-    // like ordered pairing functions and its likes to make it truely unique
-    gen.reset(new std::mt19937(n_rows * n_cols));
-  } else {
-    gen.reset(new std::mt19937(rdev()));
-  }
-  std::uniform_int_distribution<size_t> label(0, 1);
-  std::uniform_int_distribution<size_t> dis(1, n_cols);
-
-  for (size_t i = 0; i < n_rows; ++i) {
-    // Make sure that all cols are slotted in the first few rows; randomly distribute the
-    // rest
-    std::stringstream row_data;
-    size_t j = 0;
-    if (rem_cols > 0) {
-      for (; j < std::min(static_cast<size_t>(rem_cols), cols_per_row); ++j) {
-        row_data << label(*gen) << " " << (col_idx + j) << ":"
-                 << (col_idx + j + 1) * 10 * i;
-      }
-      rem_cols -= cols_per_row;
-    } else {
-      // Take some random number of colums in [1, n_cols] and slot them here
-      std::vector<size_t> random_columns;
-      size_t ncols = dis(*gen);
-      for (; j < ncols; ++j) {
-        size_t fid = (col_idx + j) % n_cols;
-        random_columns.push_back(fid);
-      }
-      std::sort(random_columns.begin(), random_columns.end());
-      for (auto fid : random_columns) {
-        row_data << label(*gen) << " " << fid << ":" << (fid + 1) * 10 * i;
-      }
-    }
-    col_idx += j;
-
-    fo << row_data.str() << "\n";
-  }
-  fo.close();
-
-  std::string uri = tmp_file + "?format=libsvm";
-  if (page_size > 0) {
-    uri += "#" + tmp_file + ".cache";
-  }
-  std::unique_ptr<DMatrix> dmat(DMatrix::Load(uri));
-  return dmat;
-}
-
 std::unique_ptr<GradientBooster> CreateTrainedGBM(std::string name, Args kwargs, size_t kRows,
                                                   size_t kCols,
                                                   LearnerModelParam const* learner_model_param,
@@ -3,12 +3,9 @@
  */
 #include <xgboost/c_api.h>
 
-#include "../../src/data/device_adapter.cuh"
-#include "../../src/data/iterative_dmatrix.h"
 #include "helpers.h"
 
 namespace xgboost {
 
 CudaArrayIterForTest::CudaArrayIterForTest(float sparsity, size_t rows,
                                            size_t cols, size_t batches)
     : ArrayIterForTest{sparsity, rows, cols, batches} {
@@ -26,14 +23,4 @@ int CudaArrayIterForTest::Next() {
   iter_++;
   return 1;
 }
-
-std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDeviceDMatrix(bool with_label) {
-  CudaArrayIterForTest iter{this->sparsity_, this->rows_, this->cols_, 1};
-  auto m = std::make_shared<data::IterativeDMatrix>(
-      &iter, iter.Proxy(), nullptr, Reset, Next, std::numeric_limits<float>::quiet_NaN(), 0, bins_);
-  if (with_label) {
-    this->GenerateLabels(m);
-  }
-  return m;
-}
 } // namespace xgboost
@@ -324,9 +324,6 @@ class RandomDataGenerator {
   [[nodiscard]] std::shared_ptr<DMatrix> GenerateExtMemQuantileDMatrix(std::string prefix,
                                                                        bool with_label) const;
 
-#if defined(XGBOOST_USE_CUDA)
-  std::shared_ptr<DMatrix> GenerateDeviceDMatrix(bool with_label);
-#endif
   std::shared_ptr<DMatrix> GenerateQuantileDMatrix(bool with_label);
 };
 
@@ -350,45 +347,6 @@ inline std::vector<float> GenerateRandomCategoricalSingleColumn(int n, size_t nu
 std::shared_ptr<DMatrix> GetDMatrixFromData(const std::vector<float>& x, std::size_t num_rows,
                                             bst_feature_t num_columns);
 
-/**
- * \brief Create Sparse Page using data iterator.
- *
- * \param n_samples Total number of rows for all batches combined.
- * \param n_features Number of features
- * \param n_batches Number of batches
- * \param prefix Cache prefix, can be used for specifying file path.
- *
- * \return A Sparse DMatrix with n_batches.
- */
-std::unique_ptr<DMatrix> CreateSparsePageDMatrix(bst_idx_t n_samples, bst_feature_t n_features,
-                                                 size_t n_batches, std::string prefix = "cache");
-
-/**
- * Deprecated, stop using it
- */
-std::unique_ptr<DMatrix> CreateSparsePageDMatrix(size_t n_entries, std::string prefix = "cache");
-
-/**
- * Deprecated, stop using it
- *
- * \brief Creates dmatrix with some records, each record containing random number of
- *        features in [1, n_cols]
- *
- * \param n_rows Number of records to create.
- * \param n_cols Max number of features within that record.
- * \param page_size Sparse page size for the pages within the dmatrix. If page size is 0
- *                  then the entire dmatrix is resident in memory; else, multiple sparse pages
- *                  of page size are created and backed to disk, which would have to be
- *                  streamed in at point of use.
- * \param deterministic The content inside the dmatrix is constant for this configuration, if true;
- *                      else, the content changes every time this method is invoked
- *
- * \return The new dmatrix.
- */
-std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
-    size_t n_rows, size_t n_cols, size_t page_size, bool deterministic,
-    const dmlc::TemporaryDirectory& tempdir = dmlc::TemporaryDirectory());
-
 std::unique_ptr<GradientBooster> CreateTrainedGBM(std::string name, Args kwargs, size_t kRows,
                                                   size_t kCols,
                                                   LearnerModelParam const* learner_model_param,
@@ -36,9 +36,10 @@ TEST(SyclPredictor, ExternalMemory) {
   Context ctx;
   ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
 
-  size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
-  size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
-  std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries);
+  bst_idx_t constexpr kRows{64};
+  bst_feature_t constexpr kCols{12};
+  auto dmat =
+      RandomDataGenerator{kRows, kCols, 0.5f}.Batches(3).GenerateSparsePageDMatrix("temp", true);
   TestBasic(dmat.get(), &ctx);
 }
 
@@ -10,12 +10,10 @@
 #include "../../../src/gbm/gbtree.h"
 #include "../../../src/gbm/gbtree_model.h"
 #include "../collective/test_worker.h"  // for TestDistributedGlobal
-#include "../filesystem.h"              // dmlc::TemporaryDirectory
 #include "../helpers.h"
 #include "test_predictor.h"
 
 namespace xgboost {
 
 TEST(CpuPredictor, Basic) {
   Context ctx;
   size_t constexpr kRows = 5;
@@ -56,9 +54,10 @@ TEST(CpuPredictor, IterationRangeColmnSplit) {
 
 TEST(CpuPredictor, ExternalMemory) {
   Context ctx;
-  size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
-  size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
-  std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries);
+  bst_idx_t constexpr kRows{64};
+  bst_feature_t constexpr kCols{12};
+  auto dmat =
+      RandomDataGenerator{kRows, kCols, 0.5f}.Batches(3).GenerateSparsePageDMatrix("temp", true);
   TestBasic(dmat.get(), &ctx);
 }
 
@@ -123,8 +123,8 @@ TEST(GPUPredictor, EllpackBasic) {
     size_t rows = bins * 16;
     auto p_m = RandomDataGenerator{rows, kCols, 0.0}
                    .Bins(bins)
-                   .Device(DeviceOrd::CUDA(0))
-                   .GenerateDeviceDMatrix(false);
+                   .Device(ctx.Device())
+                   .GenerateQuantileDMatrix(false);
     ASSERT_FALSE(p_m->PageExists<SparsePage>());
     TestPredictionFromGradientIndex<EllpackPage>(&ctx, rows, kCols, p_m);
     TestPredictionFromGradientIndex<EllpackPage>(&ctx, bins, kCols, p_m);
@@ -137,7 +137,7 @@ TEST(GPUPredictor, EllpackTraining) {
   auto p_ellpack = RandomDataGenerator{kRows, kCols, 0.0}
                        .Bins(kBins)
                        .Device(ctx.Device())
-                       .GenerateDeviceDMatrix(false);
+                       .GenerateQuantileDMatrix(false);
   HostDeviceVector<float> storage(kRows * kCols);
   auto columnar =
       RandomDataGenerator{kRows, kCols, 0.0}.Device(ctx.Device()).GenerateArrayInterface(&storage);
@@ -117,24 +117,15 @@ TEST(Learner, CheckGroup) {
   EXPECT_ANY_THROW(learner->UpdateOneIter(0, p_mat));
 }
 
-TEST(Learner, SLOW_CheckMultiBatch) {  // NOLINT
-  // Create sufficiently large data to make two row pages
-  dmlc::TemporaryDirectory tempdir;
-  const std::string tmp_file = tempdir.path + "/big.libsvm";
-  CreateBigTestData(tmp_file, 50000);
-  std::shared_ptr<DMatrix> dmat(
-      xgboost::DMatrix::Load(tmp_file + "?format=libsvm" + "#" + tmp_file + ".cache"));
-  EXPECT_FALSE(dmat->SingleColBlock());
-  size_t num_row = dmat->Info().num_row_;
-  std::vector<bst_float> labels(num_row);
-  for (size_t i = 0; i < num_row; ++i) {
-    labels[i] = i % 2;
-  }
-  dmat->SetInfo("label", Make1dInterfaceTest(labels.data(), num_row));
-  std::vector<std::shared_ptr<DMatrix>> mat{dmat};
+TEST(Learner, CheckMultiBatch) {
+  auto p_fmat =
+      RandomDataGenerator{512, 128, 0.8}.Batches(4).GenerateSparsePageDMatrix("temp", true);
+  ASSERT_FALSE(p_fmat->SingleColBlock());
+
+  std::vector<std::shared_ptr<DMatrix>> mat{p_fmat};
   auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
   learner->SetParams(Args{{"objective", "binary:logistic"}});
-  learner->UpdateOneIter(0, dmat);
+  learner->UpdateOneIter(0, p_fmat);
 }
 
 TEST(Learner, Configuration) {
@@ -7,22 +7,18 @@
 #include "../../../../src/tree/gpu_hist/gradient_based_sampler.cuh"
 #include "../../../../src/tree/param.h"
 #include "../../../../src/tree/param.h"  // TrainParam
-#include "../../filesystem.h"  // dmlc::TemporaryDirectory
 #include "../../helpers.h"
 
 namespace xgboost::tree {
-void VerifySampling(size_t page_size,
-                    float subsample,
-                    int sampling_method,
-                    bool fixed_size_sampling = true,
-                    bool check_sum = true) {
+void VerifySampling(size_t page_size, float subsample, int sampling_method,
+                    bool fixed_size_sampling = true, bool check_sum = true) {
   constexpr size_t kRows = 4096;
   constexpr size_t kCols = 1;
-  size_t sample_rows = kRows * subsample;
+  bst_idx_t sample_rows = kRows * subsample;
+  bst_idx_t n_batches = fixed_size_sampling ? 1 : 4;
 
-  dmlc::TemporaryDirectory tmpdir;
-  std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrix(
-      kRows, kCols, kRows / (page_size == 0 ? kRows : page_size), tmpdir.path + "/cache"));
+  auto dmat = RandomDataGenerator{kRows, kCols, 0.0f}.Batches(n_batches).GenerateSparsePageDMatrix(
+      "temp", true);
   auto gpair = GenerateRandomGradients(kRows);
   GradientPair sum_gpair{};
   for (const auto& gp : gpair.ConstHostVector()) {
@@ -78,14 +74,12 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) {
   constexpr size_t kRows = 2048;
   constexpr size_t kCols = 1;
   constexpr float kSubsample = 1.0f;
-  constexpr size_t kPageSize = 1024;
 
   // Create a DMatrix with multiple batches.
-  dmlc::TemporaryDirectory tmpdir;
-  std::unique_ptr<DMatrix> dmat(
-      CreateSparsePageDMatrix(kRows, kCols, kRows / kPageSize, tmpdir.path + "/cache"));
+  auto dmat =
+      RandomDataGenerator{kRows, kCols, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true);
   auto gpair = GenerateRandomGradients(kRows);
-  Context ctx{MakeCUDACtx(0)};
+  auto ctx = MakeCUDACtx(0);
   gpair.SetDevice(ctx.Device());
 
   auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
@@ -406,7 +406,8 @@ namespace {
 void TestHistogramExternalMemory(Context const *ctx, BatchParam batch_param, bool is_approx,
                                  bool force_read_by_column) {
   size_t constexpr kEntries = 1 << 16;
-  auto m = CreateSparsePageDMatrix(kEntries, "cache");
+  auto m =
+      RandomDataGenerator{kEntries / 8, 8, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true);
 
   std::vector<float> hess(m->Info().num_row_, 1.0);
   if (is_approx) {
@@ -17,12 +17,11 @@
 #include "../../../src/common/random.h"  // for GlobalRandom
 #include "../../../src/tree/param.h"     // for TrainParam
 #include "../collective/test_worker.h"   // for BaseMGPUTest
-#include "../filesystem.h"               // dmlc::TemporaryDirectory
 #include "../helpers.h"
 
 namespace xgboost::tree {
 namespace {
-void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix* dmat, bool is_ext,
+void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix* dmat,
                 RegTree* tree, HostDeviceVector<bst_float>* preds, float subsample,
                 const std::string& sampling_method, bst_bin_t max_bin) {
   Args args{
@@ -45,7 +44,7 @@ void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix
   hist_maker->Update(&param, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
                      {tree});
   auto cache = linalg::MakeTensorView(ctx, preds->DeviceSpan(), preds->Size(), 1);
-  if (subsample < 1.0 && is_ext) {
+  if (subsample < 1.0 && !dmat->SingleColBlock()) {
     ASSERT_FALSE(hist_maker->UpdatePredictionCache(dmat, cache));
   } else {
     ASSERT_TRUE(hist_maker->UpdatePredictionCache(dmat, cache));
@@ -58,22 +57,23 @@ TEST(GpuHist, UniformSampling) {
   constexpr size_t kCols = 2;
   constexpr float kSubsample = 0.9999;
   common::GlobalRandom().seed(1994);
+  auto ctx = MakeCUDACtx(0);
 
   // Create an in-memory DMatrix.
-  std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
+  auto p_fmat = RandomDataGenerator{kRows, kCols, 0.0f}.GenerateDMatrix(true);
+  ASSERT_TRUE(p_fmat->SingleColBlock());
 
-  linalg::Matrix<GradientPair> gpair({kRows}, Context{}.MakeCUDA().Device());
+  linalg::Matrix<GradientPair> gpair({kRows}, ctx.Device());
   gpair.Data()->Copy(GenerateRandomGradients(kRows));
 
   // Build a tree using the in-memory DMatrix.
   RegTree tree;
-  HostDeviceVector<bst_float> preds(kRows, 0.0, DeviceOrd::CUDA(0));
-  Context ctx(MakeCUDACtx(0));
-  UpdateTree(&ctx, &gpair, dmat.get(), false, &tree, &preds, 1.0, "uniform", kRows);
+  HostDeviceVector<bst_float> preds(kRows, 0.0, ctx.Device());
+  UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, 1.0, "uniform", kRows);
   // Build another tree using sampling.
   RegTree tree_sampling;
-  HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, DeviceOrd::CUDA(0));
-  UpdateTree(&ctx, &gpair, dmat.get(), false, &tree_sampling, &preds_sampling, kSubsample, "uniform",
+  HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, ctx.Device());
+  UpdateTree(&ctx, &gpair, p_fmat.get(), &tree_sampling, &preds_sampling, kSubsample, "uniform",
              kRows);
 
   // Make sure the predictions are the same.
@@ -89,23 +89,23 @@ TEST(GpuHist, GradientBasedSampling) {
   constexpr size_t kCols = 2;
   constexpr float kSubsample = 0.9999;
   common::GlobalRandom().seed(1994);
+  auto ctx = MakeCUDACtx(0);
 
   // Create an in-memory DMatrix.
-  std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
+  auto p_fmat = RandomDataGenerator{kRows, kCols, 0.0f}.GenerateDMatrix(true);
 
-  linalg::Matrix<GradientPair> gpair({kRows}, MakeCUDACtx(0).Device());
+  linalg::Matrix<GradientPair> gpair({kRows}, ctx.Device());
   gpair.Data()->Copy(GenerateRandomGradients(kRows));
 
   // Build a tree using the in-memory DMatrix.
   RegTree tree;
-  HostDeviceVector<bst_float> preds(kRows, 0.0, DeviceOrd::CUDA(0));
-  Context ctx(MakeCUDACtx(0));
-  UpdateTree(&ctx, &gpair, dmat.get(), false, &tree, &preds, 1.0, "uniform", kRows);
+  HostDeviceVector<bst_float> preds(kRows, 0.0, ctx.Device());
+  UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, 1.0, "uniform", kRows);
 
   // Build another tree using sampling.
   RegTree tree_sampling;
-  HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, DeviceOrd::CUDA(0));
-  UpdateTree(&ctx, &gpair, dmat.get(), false, &tree_sampling, &preds_sampling, kSubsample,
+  HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, ctx.Device());
+  UpdateTree(&ctx, &gpair, p_fmat.get(), &tree_sampling, &preds_sampling, kSubsample,
              "gradient_based", kRows);
 
   // Make sure the predictions are the same.
@@ -119,29 +119,29 @@ TEST(GpuHist, GradientBasedSampling) {
 TEST(GpuHist, ExternalMemory) {
   constexpr size_t kRows = 4096;
   constexpr size_t kCols = 2;
-  constexpr size_t kPageSize = 1024;
-
-  dmlc::TemporaryDirectory tmpdir;
 
   // Create a DMatrix with multiple batches.
-  std::unique_ptr<DMatrix> dmat_ext(
-      CreateSparsePageDMatrix(kRows, kCols, kRows / kPageSize, tmpdir.path + "/cache"));
+  auto p_fmat_ext =
+      RandomDataGenerator{kRows, kCols, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true);
+  ASSERT_FALSE(p_fmat_ext->SingleColBlock());
 
   // Create a single batch DMatrix.
-  std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrix(kRows, kCols, 1, tmpdir.path + "/cache"));
+  auto p_fmat =
+      RandomDataGenerator{kRows, kCols, 0.0f}.Batches(1).GenerateSparsePageDMatrix("temp", true);
+  ASSERT_TRUE(p_fmat->SingleColBlock());
 
-  Context ctx(MakeCUDACtx(0));
+  auto ctx = MakeCUDACtx(0);
   linalg::Matrix<GradientPair> gpair({kRows}, ctx.Device());
   gpair.Data()->Copy(GenerateRandomGradients(kRows));
 
   // Build a tree using the in-memory DMatrix.
   RegTree tree;
-  HostDeviceVector<bst_float> preds(kRows, 0.0, DeviceOrd::CUDA(0));
-  UpdateTree(&ctx, &gpair, dmat.get(), false, &tree, &preds, 1.0, "uniform", kRows);
+  HostDeviceVector<bst_float> preds(kRows, 0.0, ctx.Device());
+  UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, 1.0, "uniform", kRows);
   // Build another tree using multiple ELLPACK pages.
   RegTree tree_ext;
-  HostDeviceVector<bst_float> preds_ext(kRows, 0.0, DeviceOrd::CUDA(0));
-  UpdateTree(&ctx, &gpair, dmat_ext.get(), true, &tree_ext, &preds_ext, 1.0, "uniform", kRows);
+  HostDeviceVector<bst_float> preds_ext(kRows, 0.0, ctx.Device());
+  UpdateTree(&ctx, &gpair, p_fmat_ext.get(), &tree_ext, &preds_ext, 1.0, "uniform", kRows);
 
   // Make sure the predictions are the same.
   auto preds_h = preds.ConstHostVector();
@@ -157,20 +157,21 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
   const std::string kSamplingMethod = "gradient_based";
   common::GlobalRandom().seed(0);
 
-  dmlc::TemporaryDirectory tmpdir;
-  Context ctx(MakeCUDACtx(0));
+  auto ctx = MakeCUDACtx(0);
 
   // Create a single batch DMatrix.
   auto p_fmat = RandomDataGenerator{kRows, kCols, 0.0f}
                     .Device(ctx.Device())
                     .Batches(1)
                     .GenerateSparsePageDMatrix("temp", true);
+  ASSERT_TRUE(p_fmat->SingleColBlock());
 
   // Create a DMatrix with multiple batches.
   auto p_fmat_ext = RandomDataGenerator{kRows, kCols, 0.0f}
                         .Device(ctx.Device())
                         .Batches(4)
                         .GenerateSparsePageDMatrix("temp", true);
+  ASSERT_FALSE(p_fmat_ext->SingleColBlock());
 
   linalg::Matrix<GradientPair> gpair({kRows}, ctx.Device());
   gpair.Data()->Copy(GenerateRandomGradients(kRows));
@@ -179,26 +180,25 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
   auto rng = common::GlobalRandom();
 
   RegTree tree;
-  HostDeviceVector<bst_float> preds(kRows, 0.0, DeviceOrd::CUDA(0));
-  UpdateTree(&ctx, &gpair, p_fmat.get(), true, &tree, &preds, kSubsample, kSamplingMethod, kRows);
+  HostDeviceVector<bst_float> preds(kRows, 0.0, ctx.Device());
+  UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, kSubsample, kSamplingMethod, kRows);
 
   // Build another tree using multiple ELLPACK pages.
   common::GlobalRandom() = rng;
   RegTree tree_ext;
-  HostDeviceVector<bst_float> preds_ext(kRows, 0.0, DeviceOrd::CUDA(0));
-  UpdateTree(&ctx, &gpair, p_fmat_ext.get(), true, &tree_ext, &preds_ext, kSubsample,
-             kSamplingMethod, kRows);
+  HostDeviceVector<bst_float> preds_ext(kRows, 0.0, ctx.Device());
+  UpdateTree(&ctx, &gpair, p_fmat_ext.get(), &tree_ext, &preds_ext, kSubsample, kSamplingMethod,
+             kRows);
 
-  // Make sure the predictions are the same.
-  auto preds_h = preds.ConstHostVector();
-  auto preds_ext_h = preds_ext.ConstHostVector();
-  for (size_t i = 0; i < kRows; i++) {
-    ASSERT_NEAR(preds_h[i], preds_ext_h[i], 1e-3);
-  }
+  Json jtree{Object{}};
+  Json jtree_ext{Object{}};
+  tree.SaveModel(&jtree);
+  tree_ext.SaveModel(&jtree_ext);
+  ASSERT_EQ(jtree, jtree_ext);
 }
 
 TEST(GpuHist, ConfigIO) {
-  Context ctx(MakeCUDACtx(0));
+  auto ctx = MakeCUDACtx(0);
   ObjInfo task{ObjInfo::kRegression};
   std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_gpu_hist", &ctx, &task)};
   updater->Configure(Args{});