Implement sketching with Hessian on GPU. (#9399)

- Prepare for implementing approx on GPU.
- Unify the code path between weighted and uniform sketching on DMatrix.
This commit is contained in:
Jiaming Yuan
2023-07-24 15:43:03 +08:00
committed by GitHub
parent 851cba931e
commit a196443a07
14 changed files with 446 additions and 230 deletions

View File

@@ -185,10 +185,10 @@ class MetaInfo {
return data_split_mode == DataSplitMode::kRow;
}
/*! \brief Whether the data is split column-wise. */
bool IsColumnSplit() const {
return data_split_mode == DataSplitMode::kCol;
}
/** @brief Whether the data is split column-wise. */
bool IsColumnSplit() const { return data_split_mode == DataSplitMode::kCol; }
/** @brief Whether this is learning-to-rank data. */
bool IsRanking() const { return !group_ptr_.empty(); }
/*!
* \brief A convenient method to check if we are doing vertical federated learning, which requires
@@ -249,7 +249,7 @@ struct BatchParam {
/**
* \brief Hessian, used for sketching with future approx implementation.
*/
common::Span<float> hess;
common::Span<float const> hess;
/**
* \brief Whether we should force DMatrix to regenerate the batch. Only used for
* GHistIndex.
@@ -279,7 +279,7 @@ struct BatchParam {
* Get batch with sketch weighted by hessian. The batch will be regenerated if the
* span is changed, so caller should keep the span for each iteration.
*/
BatchParam(bst_bin_t max_bin, common::Span<float> hessian, bool regenerate)
BatchParam(bst_bin_t max_bin, common::Span<float const> hessian, bool regenerate)
: max_bin{max_bin}, hess{hessian}, regen{regenerate} {}
[[nodiscard]] bool ParamNotEqual(BatchParam const& other) const {

View File

@@ -49,11 +49,12 @@
#ifndef XGBOOST_HOST_DEVICE_VECTOR_H_
#define XGBOOST_HOST_DEVICE_VECTOR_H_
#include <initializer_list>
#include <vector>
#include <type_traits>
#include <xgboost/context.h> // for DeviceOrd
#include <xgboost/span.h> // for Span
#include "span.h"
#include <initializer_list>
#include <type_traits>
#include <vector>
namespace xgboost {
@@ -133,6 +134,7 @@ class HostDeviceVector {
GPUAccess DeviceAccess() const;
void SetDevice(int device) const;
void SetDevice(DeviceOrd device) const;
void Resize(size_t new_size, T v = T());