Use Booster context in DMatrix. (#8896)
- Pass context from booster to DMatrix. - Use context instead of integer for `n_threads`. - Check the consistency configuration for `max_bin`. - Test for all combinations of initialization options.
This commit is contained in:
@@ -1,12 +1,14 @@
|
||||
/*!
|
||||
* Copyright 2019-2021 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2019-2023, XGBoost Contributors
|
||||
* \file simple_dmatrix.cu
|
||||
*/
|
||||
#include <thrust/copy.h>
|
||||
#include <xgboost/data.h>
|
||||
|
||||
#include "device_adapter.cuh" // for CurrentDevice
|
||||
#include "simple_dmatrix.cuh"
|
||||
#include "simple_dmatrix.h"
|
||||
#include "device_adapter.cuh"
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/data.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
@@ -15,7 +17,7 @@ namespace data {
|
||||
// Current implementation assumes a single batch. More batches can
|
||||
// be supported in future. Does not currently support inferring row/column size
|
||||
template <typename AdapterT>
|
||||
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread*/,
|
||||
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, std::int32_t nthread,
|
||||
DataSplitMode data_split_mode) {
|
||||
CHECK(data_split_mode != DataSplitMode::kCol)
|
||||
<< "Column-wise data split is currently not supported on the GPU.";
|
||||
@@ -24,6 +26,9 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread
|
||||
CHECK_GE(device, 0);
|
||||
dh::safe_cuda(cudaSetDevice(device));
|
||||
|
||||
Context ctx;
|
||||
ctx.Init(Args{{"nthread", std::to_string(nthread)}, {"gpu_id", std::to_string(device)}});
|
||||
|
||||
CHECK(adapter->NumRows() != kAdapterUnknownSize);
|
||||
CHECK(adapter->NumColumns() != kAdapterUnknownSize);
|
||||
|
||||
@@ -33,13 +38,14 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread
|
||||
// Enforce single batch
|
||||
CHECK(!adapter->Next());
|
||||
|
||||
info_.num_nonzero_ =
|
||||
CopyToSparsePage(adapter->Value(), device, missing, sparse_page_.get());
|
||||
info_.num_nonzero_ = CopyToSparsePage(adapter->Value(), device, missing, sparse_page_.get());
|
||||
info_.num_col_ = adapter->NumColumns();
|
||||
info_.num_row_ = adapter->NumRows();
|
||||
// Synchronise worker columns
|
||||
info_.data_split_mode = data_split_mode;
|
||||
info_.SynchronizeNumberOfColumns();
|
||||
|
||||
this->fmat_ctx_ = ctx;
|
||||
}
|
||||
|
||||
template SimpleDMatrix::SimpleDMatrix(CudfAdapter* adapter, float missing,
|
||||
|
||||
Reference in New Issue
Block a user