xgboost/src/data/simple_dmatrix.cu
Jiaming Yuan 06bdc15e9b
[coll] Pass context to various functions. (#9772)
* [coll] Pass context to various functions.

In the future, the `Context` object would be required for collective operations, this PR
passes the context object to some required functions to prepare for swapping out the
implementation.
2023-11-08 09:54:05 +08:00

55 lines
2.0 KiB
Plaintext

/**
* Copyright 2019-2023, XGBoost Contributors
* \file simple_dmatrix.cu
*/
#include <thrust/copy.h>
#include "device_adapter.cuh" // for CurrentDevice
#include "simple_dmatrix.cuh"
#include "simple_dmatrix.h"
#include "xgboost/context.h" // for Context
#include "xgboost/data.h"
namespace xgboost::data {
/**
 * \brief Build a SimpleDMatrix from an adapter over device-side data.
 *
 * Current limitations:
 *  - Meta info is not supported, as no on-device data source contains it.
 *  - A single batch is assumed; more batches can be supported in future.
 *  - Row/column sizes are not inferred, the adapter must report both.
 *
 * \param adapter         Provides the input batch, the device and the row/column counts.
 * \param missing         Forwarded to CopyToSparsePage (presumably the value that marks a
 *                        missing entry -- confirm against CopyToSparsePage).
 * \param nthread         Handed to the Context as its "nthread" argument.
 * \param data_split_mode Recorded in the meta info; DataSplitMode::kCol is rejected.
 */
template <typename AdapterT>
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, std::int32_t nthread,
                             DataSplitMode data_split_mode) {
  CHECK(data_split_mode != DataSplitMode::kCol)
      << "Column-wise data split is currently not supported on the GPU.";

  // When the adapter reports a CPU device, or has no rows, use the CUDA device that is
  // current for this thread instead of the one reported by the adapter.
  bool const use_current_device = adapter->Device().IsCPU() || adapter->NumRows() == 0;
  auto device = use_current_device ? DeviceOrd::CUDA(dh::CurrentDevice()) : adapter->Device();
  CHECK(device.IsCUDA());
  dh::safe_cuda(cudaSetDevice(device.ordinal));

  // Local context configured with the requested thread count and the chosen device.
  Context local_ctx;
  Args const ctx_args{{"nthread", std::to_string(nthread)}, {"device", device.Name()}};
  local_ctx.Init(ctx_args);

  // Both dimensions have to be known up front.
  CHECK(adapter->NumRows() != kAdapterUnknownSize);
  CHECK(adapter->NumColumns() != kAdapterUnknownSize);

  // Advance to the first batch, then verify that there is no second one.
  adapter->BeforeFirst();
  adapter->Next();
  // Enforce single batch
  CHECK(!adapter->Next());

  auto const& batch = adapter->Value();
  info_.num_nonzero_ = CopyToSparsePage(batch, device, missing, sparse_page_.get());
  info_.num_row_ = adapter->NumRows();
  info_.num_col_ = adapter->NumColumns();

  // Synchronise worker columns
  info_.data_split_mode = data_split_mode;
  info_.SynchronizeNumberOfColumns(&local_ctx);

  fmat_ctx_ = local_ctx;
}
// Explicit instantiations of the adapter constructor for the two adapter types used here.
template SimpleDMatrix::SimpleDMatrix(CudfAdapter* adapter, float missing,
                                      std::int32_t nthread,
                                      DataSplitMode data_split_mode);
template SimpleDMatrix::SimpleDMatrix(CupyAdapter* adapter, float missing,
                                      std::int32_t nthread,
                                      DataSplitMode data_split_mode);
} // namespace xgboost::data