Small cleanups to various data types. (#8086)

- Use `bst_bin_t` for the bin count in the `BatchParam` constructors.
- Use `StringView` to avoid `std::string` copies where appropriate.
- Avoid using `MetaInfo` in the quantile sketch constructors to limit the scope of the parameters.
Jiaming Yuan, 2022-07-18 22:39:36 +08:00 (committed by GitHub)
parent e28f6f6657
commit 4083440690
9 changed files with 52 additions and 53 deletions


@@ -216,7 +216,7 @@ struct BatchParam {
   /*! \brief The GPU device to use. */
   int gpu_id {-1};
   /*! \brief Maximum number of bins per feature for histograms. */
-  int max_bin{0};
+  bst_bin_t max_bin{0};
   /*! \brief Hessian, used for sketching with future approx implementation. */
   common::Span<float> hess;
   /*! \brief Whether should DMatrix regenerate the batch. Only used for GHistIndex. */
@@ -226,17 +226,17 @@
   BatchParam() = default;
   // GPU Hist
-  BatchParam(int32_t device, int32_t max_bin)
+  BatchParam(int32_t device, bst_bin_t max_bin)
       : gpu_id{device}, max_bin{max_bin} {}
   // Hist
-  BatchParam(int32_t max_bin, double sparse_thresh)
+  BatchParam(bst_bin_t max_bin, double sparse_thresh)
       : max_bin{max_bin}, sparse_thresh{sparse_thresh} {}
   // Approx
   /**
    * \brief Get batch with sketch weighted by hessian. The batch will be regenerated if
    *        the span is changed, so caller should keep the span for each iteration.
    */
-  BatchParam(int32_t max_bin, common::Span<float> hessian, bool regenerate)
+  BatchParam(bst_bin_t max_bin, common::Span<float> hessian, bool regenerate)
       : max_bin{max_bin}, hess{hessian}, regen{regenerate} {}
   bool operator!=(BatchParam const& other) const {
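
For illustration, a minimal compilable sketch of the typed-constructor pattern above. `BatchParamSketch` is a hypothetical, trimmed stand-in for `BatchParam`, and `bst_bin_t = std::int32_t` is an assumption (matching the alias in `xgboost/base.h`):

```cpp
#include <cstdint>

using bst_bin_t = std::int32_t;  // assumed alias for bin counts/indices

struct BatchParamSketch {
  std::int32_t gpu_id{-1};
  bst_bin_t max_bin{0};
  double sparse_thresh{0.0};

  BatchParamSketch() = default;
  // GPU Hist: device ordinal first, then the bin count.
  BatchParamSketch(std::int32_t device, bst_bin_t max_bin)
      : gpu_id{device}, max_bin{max_bin} {}
  // CPU Hist: bin count first, then the sparse threshold.
  BatchParamSketch(bst_bin_t max_bin, double sparse_thresh)
      : max_bin{max_bin}, sparse_thresh{sparse_thresh} {}
};

int main() {
  BatchParamSketch gpu{0, 256};    // selects the (device, max_bin) overload
  BatchParamSketch cpu{256, 0.2};  // selects the (max_bin, sparse_thresh) overload
  return gpu.max_bin == cpu.max_bin ? 0 : 1;
}
```

A dedicated alias documents the parameter's role at each call site and keeps the integer width changeable in one place.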


@@ -49,14 +49,14 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
   }
   if (!use_sorted) {
-    HostSketchContainer container(max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info),
-                                  n_threads);
+    HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
+                                  HostSketchContainer::UseGroup(info), n_threads);
     for (auto const& page : m->GetBatches<SparsePage>()) {
       container.PushRowPage(page, info, hessian);
     }
     container.MakeCuts(&out);
   } else {
-    SortedSketchContainer container{max_bins, m->Info(), reduced,
+    SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
                                     HostSketchContainer::UseGroup(info), n_threads};
     for (auto const& page : m->GetBatches<SortedCSCPage>()) {
       container.PushColPage(page, info, hessian);
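
The call sites above now hand the containers a feature-type span rather than the whole `MetaInfo`. A compilable C++20 sketch of why this narrows the dependency; `std::span`, `FeatureType`, and `MetaInfoLike` below are stand-ins, not the real xgboost types:

```cpp
#include <cstddef>
#include <span>
#include <vector>

enum class FeatureType : unsigned char { kNumerical = 0, kCategorical = 1 };

// Hypothetical stand-in for MetaInfo: many fields, most irrelevant to sketching.
struct MetaInfoLike {
  std::size_t num_col_{0};
  std::vector<FeatureType> feature_types;
};

// The constructor now declares exactly what it reads: a non-owning,
// read-only view of the feature types. Callers keep the storage alive.
class SketchContainerSketch {
 public:
  explicit SketchContainerSketch(std::span<FeatureType const> ft) : ft_{ft} {}
  std::size_t NumFeatures() const { return ft_.size(); }

 private:
  std::span<FeatureType const> ft_;
};

int main() {
  MetaInfoLike info;
  info.feature_types = {FeatureType::kNumerical, FeatureType::kCategorical};
  SketchContainerSketch sketch{std::span<FeatureType const>{info.feature_types}};
  return sketch.NumFeatures() == 2 ? 0 : 1;
}
```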


@@ -86,7 +86,7 @@ void SketchContainerImpl<WQSketch>::PushRowPage(SparsePage const &page, MetaInfo
 template <typename Batch>
 void HostSketchContainer::PushAdapterBatch(Batch const &batch, size_t base_rowid,
-                                           MetaInfo const &info, size_t nnz, float missing) {
+                                           MetaInfo const &info, float missing) {
   auto const &h_weights =
       (use_group_ind_ ? detail::UnrollGroupWeights(info) : info.weights_.HostVector());
@@ -94,14 +94,14 @@ void HostSketchContainer::PushAdapterBatch(Batch const &batch, size_t base_rowid
   auto weights = OptionalWeights{Span<float const>{h_weights}};
   // the nnz from info is not reliable as sketching might be the first place to go through
   // the data.
-  auto is_dense = nnz == info.num_col_ * info.num_row_;
-  this->PushRowPageImpl(batch, base_rowid, weights, nnz, info.num_col_, is_dense, is_valid);
+  auto is_dense = info.num_nonzero_ == info.num_col_ * info.num_row_;
+  this->PushRowPageImpl(batch, base_rowid, weights, info.num_nonzero_, info.num_col_, is_dense,
+                        is_valid);
 }
 #define INSTANTIATE(_type)                                                   \
   template void HostSketchContainer::PushAdapterBatch<data::_type>(          \
-      data::_type const &batch, size_t base_rowid, MetaInfo const &info, size_t nnz, \
-      float missing);
+      data::_type const &batch, size_t base_rowid, MetaInfo const &info, float missing);
 INSTANTIATE(ArrayAdapterBatch)
 INSTANTIATE(CSRArrayAdapterBatch)
@@ -436,11 +436,10 @@ void SketchContainerImpl<WQSketch>::MakeCuts(HistogramCuts* cuts) {
 template class SketchContainerImpl<WQuantileSketch<float, float>>;
 template class SketchContainerImpl<WXQuantileSketch<float, float>>;
-HostSketchContainer::HostSketchContainer(int32_t max_bins, MetaInfo const &info,
+HostSketchContainer::HostSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
                                          std::vector<size_t> columns_size, bool use_group,
                                          int32_t n_threads)
-    : SketchContainerImpl{columns_size, max_bins, info.feature_types.ConstHostSpan(), use_group,
-                          n_threads} {
+    : SketchContainerImpl{columns_size, max_bins, ft, use_group, n_threads} {
   monitor_.Init(__func__);
   ParallelFor(sketches_.size(), n_threads_, Sched::Auto(), [&](auto i) {
     auto n_bins = std::min(static_cast<size_t>(max_bins_), columns_size_[i]);
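
The `is_dense` decision above is a plain count comparison: a matrix is dense exactly when the number of stored values equals rows times columns. A standalone version, with `IsDense` as a hypothetical helper rather than a real xgboost function:

```cpp
#include <cstddef>
#include <iostream>

bool IsDense(std::size_t n_nonzero, std::size_t n_rows, std::size_t n_cols) {
  return n_nonzero == n_rows * n_cols;
}

int main() {
  std::cout << IsDense(6, 2, 3) << "\n";  // 1: a 2x3 matrix with 6 stored values is dense
  std::cout << IsDense(4, 2, 3) << "\n";  // 0: two cells are missing
}
```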


@@ -903,12 +903,11 @@ class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, fl
   using WQSketch = WQuantileSketch<float, float>;
  public:
-  HostSketchContainer(int32_t max_bins, MetaInfo const &info, std::vector<size_t> columns_size,
-                      bool use_group, int32_t n_threads);
+  HostSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
+                      std::vector<size_t> columns_size, bool use_group, int32_t n_threads);
   template <typename Batch>
-  void PushAdapterBatch(Batch const &batch, size_t base_rowid, MetaInfo const &info, size_t nnz,
-                        float missing);
+  void PushAdapterBatch(Batch const &batch, size_t base_rowid, MetaInfo const &info, float missing);
 };
 /**
@@ -1000,13 +999,12 @@ class SortedSketchContainer : public SketchContainerImpl<WXQuantileSketch<float,
   using Super = SketchContainerImpl<WXQuantileSketch<float, float>>;
  public:
-  explicit SortedSketchContainer(int32_t max_bins, MetaInfo const &info,
+  explicit SortedSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
                                  std::vector<size_t> columns_size, bool use_group,
                                  int32_t n_threads)
-      : SketchContainerImpl{columns_size, max_bins, info.feature_types.ConstHostSpan(), use_group,
-                            n_threads} {
+      : SketchContainerImpl{columns_size, max_bins, ft, use_group, n_threads} {
     monitor_.Init(__func__);
-    sketches_.resize(info.num_col_);
+    sketches_.resize(columns_size.size());
     size_t i = 0;
     for (auto &sketch : sketches_) {
       sketch.sketch = &Super::sketches_[i];
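
Since `sketches_` is now sized from `columns_size` instead of `MetaInfo::num_col_`, the container is constructible from its arguments alone, which is handy in tests. A hypothetical C++20 sketch; `SortedSketchSketch` is not a real class:

```cpp
#include <cstddef>
#include <cstdint>
#include <span>
#include <vector>

enum class FeatureType : unsigned char { kNumerical = 0, kCategorical = 1 };

class SortedSketchSketch {
 public:
  SortedSketchSketch(std::int32_t max_bins, std::span<FeatureType const> ft,
                     std::vector<std::size_t> columns_size)
      : max_bins_{max_bins}, ft_{ft} {
    // One sketch per column, derived from the argument itself rather than
    // from a MetaInfo field.
    sketches_.resize(columns_size.size());
  }
  std::size_t NumSketches() const { return sketches_.size(); }

 private:
  std::int32_t max_bins_;
  std::span<FeatureType const> ft_;
  std::vector<int> sketches_;  // placeholder element type
};

int main() {
  std::vector<FeatureType> ft(3, FeatureType::kNumerical);
  SortedSketchSketch s{256, ft, std::vector<std::size_t>(3, 100)};
  return s.NumSketches() == 3 ? 0 : 1;
}
```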


@@ -1137,16 +1137,15 @@ class SparsePageAdapterBatch {
  public:
   struct Line {
-    SparsePage::Inst inst;
+    Entry const* inst;
+    size_t n;
     bst_row_t ridx;
-    COOTuple GetElement(size_t idx) const {
-      return COOTuple{ridx, inst.data()[idx].index, inst.data()[idx].fvalue};
-    }
-    size_t Size() const { return inst.size(); }
+    COOTuple GetElement(size_t idx) const { return {ridx, inst[idx].index, inst[idx].fvalue}; }
+    size_t Size() const { return n; }
   };
   explicit SparsePageAdapterBatch(HostSparsePageView page) : page_{std::move(page)} {}
-  Line GetLine(size_t ridx) const { return Line{page_[ridx], ridx}; }
+  Line GetLine(size_t ridx) const { return Line{page_[ridx].data(), page_[ridx].size(), ridx}; }
   size_t Size() const { return page_.Size(); }
 };
 };  // namespace data
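
`Line` above replaces a span-like member with a raw pointer plus an explicit length, keeping the view trivially copyable and element access a plain dereference. A self-contained sketch; `Entry`, `COOTuple`, and `LineSketch` are reduced stand-ins for the real types:

```cpp
#include <cstddef>

struct Entry {
  std::size_t index;  // feature index
  float fvalue;       // feature value
};

struct COOTuple {
  std::size_t row_idx, column_idx;
  float value;
};

// Non-owning view over one sparse row.
struct LineSketch {
  Entry const* inst;
  std::size_t n;
  std::size_t ridx;
  COOTuple GetElement(std::size_t idx) const {
    return {ridx, inst[idx].index, inst[idx].fvalue};
  }
  std::size_t Size() const { return n; }
};

int main() {
  Entry row[] = {{0, 1.0f}, {3, 2.5f}};
  LineSketch line{row, 2, /*ridx=*/7};
  auto e = line.GetElement(1);
  return (e.row_idx == 7 && e.column_idx == 3) ? 0 : 1;
}
```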


@@ -92,9 +92,8 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
  */
 class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
  public:
-  explicit CudfAdapter(std::string cuda_interfaces_str) {
-    Json interfaces =
-        Json::Load({cuda_interfaces_str.c_str(), cuda_interfaces_str.size()});
+  explicit CudfAdapter(StringView cuda_interfaces_str) {
+    Json interfaces = Json::Load(cuda_interfaces_str);
     std::vector<Json> const& json_columns = get<Array>(interfaces);
     size_t n_columns = json_columns.size();
     CHECK_GT(n_columns, 0) << "Number of columns must not equal to 0.";
@@ -123,6 +122,9 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
     columns_ = columns;
     batch_ = CudfAdapterBatch(dh::ToSpan(columns_), num_rows_);
   }
+  explicit CudfAdapter(std::string cuda_interfaces_str)
+      : CudfAdapter{StringView{cuda_interfaces_str}} {}
   const CudfAdapterBatch& Value() const override {
     CHECK_EQ(batch_.columns_.data(), columns_.data().get());
     return batch_;
@@ -163,9 +165,8 @@ class CupyAdapterBatch : public detail::NoMetaInfo {
 class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
  public:
-  explicit CupyAdapter(std::string cuda_interface_str) {
-    Json json_array_interface =
-        Json::Load({cuda_interface_str.c_str(), cuda_interface_str.size()});
+  explicit CupyAdapter(StringView cuda_interface_str) {
+    Json json_array_interface = Json::Load(cuda_interface_str);
     array_interface_ = ArrayInterface<2>(get<Object const>(json_array_interface));
     batch_ = CupyAdapterBatch(array_interface_);
     if (array_interface_.Shape(0) == 0) {
@@ -174,6 +175,8 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
     device_idx_ = dh::CudaGetPointerDevice(array_interface_.data);
     CHECK_NE(device_idx_, -1);
   }
+  explicit CupyAdapter(std::string cuda_interface_str)
+      : CupyAdapter{StringView{cuda_interface_str}} {}
   const CupyAdapterBatch& Value() const override { return batch_; }
   size_t NumRows() const { return array_interface_.Shape(0); }
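
Both adapters keep the parsing logic in the `StringView` constructor and let the `std::string` overload delegate to it. A minimal C++17 sketch of the idiom; `std::string_view` stands in for xgboost's `StringView`, and `AdapterSketch` is hypothetical:

```cpp
#include <cstddef>
#include <string>
#include <string_view>

class AdapterSketch {
 public:
  // Primary constructor: a non-owning view, so the JSON text is never copied.
  explicit AdapterSketch(std::string_view interface_str) : size_{interface_str.size()} {}
  // Convenience overload for existing std::string call sites; it delegates,
  // so the construction logic lives in one place.
  explicit AdapterSketch(std::string const& interface_str)
      : AdapterSketch{std::string_view{interface_str}} {}

  std::size_t Size() const { return size_; }

 private:
  std::size_t size_;
};

int main() {
  std::string json = R"({"shape": [2, 3]})";
  AdapterSketch a{json};                    // std::string overload, delegates
  AdapterSketch b{std::string_view{json}};  // view overload, zero copies
  return a.Size() == b.Size() ? 0 : 1;
}
```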


@@ -7,8 +7,8 @@
 namespace xgboost {
 namespace data {
-void DMatrixProxy::FromCudaColumnar(std::string interface_str) {
-  std::shared_ptr<data::CudfAdapter> adapter {new data::CudfAdapter(interface_str)};
+void DMatrixProxy::FromCudaColumnar(StringView interface_str) {
+  std::shared_ptr<data::CudfAdapter> adapter{new CudfAdapter{interface_str}};
   auto const& value = adapter->Value();
   this->batch_ = adapter;
   ctx_.gpu_id = adapter->DeviceIdx();
@@ -19,8 +19,8 @@ void DMatrixProxy::FromCudaColumnar(std::string interface_str) {
   }
 }
-void DMatrixProxy::FromCudaArray(std::string interface_str) {
-  std::shared_ptr<CupyAdapter> adapter(new CupyAdapter(interface_str));
+void DMatrixProxy::FromCudaArray(StringView interface_str) {
+  std::shared_ptr<CupyAdapter> adapter(new CupyAdapter{StringView{interface_str}});
   this->batch_ = adapter;
   ctx_.gpu_id = adapter->DeviceIdx();
   this->Info().num_col_ = adapter->NumColumns();


@@ -48,8 +48,8 @@ class DMatrixProxy : public DMatrix {
   Context ctx_;
 #if defined(XGBOOST_USE_CUDA)
-  void FromCudaColumnar(std::string interface_str);
-  void FromCudaArray(std::string interface_str);
+  void FromCudaColumnar(StringView interface_str);
+  void FromCudaArray(StringView interface_str);
 #endif  // defined(XGBOOST_USE_CUDA)
  public:
@@ -58,9 +58,8 @@ class DMatrixProxy : public DMatrix {
   void SetCUDAArray(char const* c_interface) {
     common::AssertGPUSupport();
 #if defined(XGBOOST_USE_CUDA)
-    std::string interface_str = c_interface;
-    Json json_array_interface =
-        Json::Load({interface_str.c_str(), interface_str.size()});
+    StringView interface_str{c_interface};
+    Json json_array_interface = Json::Load(interface_str);
     if (IsA<Array>(json_array_interface)) {
       this->FromCudaColumnar(interface_str);
     } else {
@@ -114,10 +113,11 @@ class DMatrixProxy : public DMatrix {
   }
 };
-inline DMatrixProxy *MakeProxy(DMatrixHandle proxy) {
-  auto proxy_handle = static_cast<std::shared_ptr<DMatrix> *>(proxy);
+inline DMatrixProxy* MakeProxy(DMatrixHandle proxy) {
+  auto proxy_handle = static_cast<std::shared_ptr<DMatrix>*>(proxy);
   CHECK(proxy_handle) << "Invalid proxy handle.";
-  DMatrixProxy *typed = static_cast<DMatrixProxy *>(proxy_handle->get());
+  DMatrixProxy* typed = static_cast<DMatrixProxy*>(proxy_handle->get());
+  CHECK(typed) << "Invalid proxy handle.";
   return typed;
 }
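
`MakeProxy` now checks both levels of the opaque handle: the outer `shared_ptr` pointer and the object it holds. A standalone sketch of the two-step validation; note the original uses `static_cast` (a null check only), while this sketch swaps in `dynamic_cast`, which additionally rejects a wrong dynamic type. `Handle`, `Base`, and `Derived` are stand-ins for `DMatrixHandle`, `DMatrix`, and `DMatrixProxy`, and `assert` stands in for `CHECK`:

```cpp
#include <cassert>
#include <memory>

struct Base { virtual ~Base() = default; };
struct Derived : Base {};

using Handle = void*;  // opaque C-API handle

inline Derived* MakeTyped(Handle handle) {
  // First check: the opaque handle itself must be non-null.
  auto ptr = static_cast<std::shared_ptr<Base>*>(handle);
  assert(ptr && "invalid handle");
  // Second check: the shared_ptr may hold nothing (or the wrong type),
  // so the inner pointer is validated separately.
  auto typed = dynamic_cast<Derived*>(ptr->get());
  assert(typed && "invalid handle");
  return typed;
}

int main() {
  std::shared_ptr<Base> owned = std::make_shared<Derived>();
  Derived* d = MakeTyped(&owned);
  return d != nullptr ? 0 : 1;
}
```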


@@ -82,8 +82,8 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
   std::vector<float> hessian(rows, 1.0);
   auto hess = Span<float const>{hessian};
-  ContainerType<use_column> sketch_distributed(n_bins, m->Info(), column_size, false,
-                                               OmpGetNumThreads(0));
+  ContainerType<use_column> sketch_distributed(n_bins, m->Info().feature_types.ConstHostSpan(),
+                                               column_size, false, OmpGetNumThreads(0));
   if (use_column) {
     for (auto const& page : m->GetBatches<SortedCSCPage>()) {
@@ -103,8 +103,8 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
   CHECK_EQ(rabit::GetWorldSize(), 1);
   std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; });
   m->Info().num_row_ = world * rows;
-  ContainerType<use_column> sketch_on_single_node(n_bins, m->Info(), column_size, false,
-                                                  OmpGetNumThreads(0));
+  ContainerType<use_column> sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(),
+                                                  column_size, false, OmpGetNumThreads(0));
   m->Info().num_row_ = rows;
   for (auto rank = 0; rank < world; ++rank) {