Use Booster context in DMatrix. (#8896)

- Pass context from booster to DMatrix.
- Use context instead of integer for `n_threads`.
- Check the consistency configuration for `max_bin`.
- Test for all combinations of initialization options.
This commit is contained in:
Jiaming Yuan
2023-04-28 21:47:14 +08:00
committed by GitHub
parent 1f9a57d17b
commit 08ce495b5d
67 changed files with 1283 additions and 935 deletions

View File

@@ -1,12 +1,14 @@
/*!
* Copyright 2019-2021 by XGBoost Contributors
/**
* Copyright 2019-2023, XGBoost Contributors
* \file simple_dmatrix.cu
*/
#include <thrust/copy.h>
#include <xgboost/data.h>
#include "device_adapter.cuh" // for CurrentDevice
#include "simple_dmatrix.cuh"
#include "simple_dmatrix.h"
#include "device_adapter.cuh"
#include "xgboost/context.h" // for Context
#include "xgboost/data.h"
namespace xgboost {
namespace data {
@@ -15,7 +17,7 @@ namespace data {
// Current implementation assumes a single batch. More batches can
// be supported in future. Does not currently support inferring row/column size
template <typename AdapterT>
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread*/,
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, std::int32_t nthread,
DataSplitMode data_split_mode) {
CHECK(data_split_mode != DataSplitMode::kCol)
<< "Column-wise data split is currently not supported on the GPU.";
@@ -24,6 +26,9 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread
CHECK_GE(device, 0);
dh::safe_cuda(cudaSetDevice(device));
Context ctx;
ctx.Init(Args{{"nthread", std::to_string(nthread)}, {"gpu_id", std::to_string(device)}});
CHECK(adapter->NumRows() != kAdapterUnknownSize);
CHECK(adapter->NumColumns() != kAdapterUnknownSize);
@@ -33,13 +38,14 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread
// Enforce single batch
CHECK(!adapter->Next());
info_.num_nonzero_ =
CopyToSparsePage(adapter->Value(), device, missing, sparse_page_.get());
info_.num_nonzero_ = CopyToSparsePage(adapter->Value(), device, missing, sparse_page_.get());
info_.num_col_ = adapter->NumColumns();
info_.num_row_ = adapter->NumRows();
// Synchronise worker columns
info_.data_split_mode = data_split_mode;
info_.SynchronizeNumberOfColumns();
this->fmat_ctx_ = ctx;
}
template SimpleDMatrix::SimpleDMatrix(CudfAdapter* adapter, float missing,