xgboost/tests/cpp/helpers.cu
Jiaming Yuan 8d7fe262d9
[EM] Enable access to the number of batches. (#10691)
- Expose `NumBatches` in `DMatrix`.
- Small cleanup for removing legacy CUDA stream and ~force CUDA context initialization~.
- Purge old external memory data generation code.
2024-08-17 02:59:45 +08:00

27 lines
671 B
Plaintext

/**
* Copyright 2020-2024, XGBoost contributors
*/
#include <xgboost/c_api.h>
#include <xgboost/logging.h>  // CHECK_EQ

#include "helpers.h"
namespace xgboost {
/**
 * @brief Build a test iterator whose batches are generated as CUDA array interfaces.
 *
 * The base ArrayIterForTest is constructed with the same arguments. The random
 * generator is then switched to FstCU() before the batches are produced, so the
 * generated data targets a CUDA device. Presumably FstCU() is the first CUDA
 * device context — TODO confirm.
 *
 * @param sparsity Forwarded to ArrayIterForTest.
 * @param rows     Forwarded to ArrayIterForTest.
 * @param cols     Forwarded to ArrayIterForTest.
 * @param batches  Forwarded to ArrayIterForTest; generation below uses n_batches_.
 */
CudaArrayIterForTest::CudaArrayIterForTest(float sparsity, size_t rows,
                                           size_t cols, size_t batches)
    : ArrayIterForTest{sparsity, rows, cols, batches} {
  // The device must be selected before generation so the data is produced for it.
  rng_->Device(FstCU());

  auto [generated_batches, generated_interface] =
      rng_->GenerateArrayInterfaceBatch(&data_, n_batches_);
  batches_ = std::move(generated_batches);
  interface_ = std::move(generated_interface);

  // Rewind to the first batch once everything has been generated.
  this->Reset();
}
/**
 * @brief Advance the iterator by one batch.
 *
 * Passes the current pre-generated CUDA array-interface string to the proxy
 * DMatrix through XGProxyDMatrixSetDataCudaArrayInterface, then moves the
 * cursor forward.
 *
 * @return 1 when a batch was handed to the proxy, 0 once every batch has been
 *         consumed.
 */
int CudaArrayIterForTest::Next() {
  // `>=` instead of `==`: if the cursor ever overshoots, keep reporting
  // exhaustion rather than reading past the end of batches_.
  if (iter_ >= n_batches_) {
    return 0;
  }
  // The C API reports failure only through its return code. Fail loudly here
  // instead of letting a rejected batch surface as an unrelated test mismatch.
  CHECK_EQ(XGProxyDMatrixSetDataCudaArrayInterface(proxy_, batches_[iter_].c_str()), 0)
      << XGBGetLastError();
  ++iter_;
  return 1;
}
} // namespace xgboost