Use matrix for gradient. (#9508)

- Use `linalg::Matrix` for storing gradients.
- New API for the custom objective.
- Custom objectives for multi-class/multi-target models are now required to return gradients and hessians with the correct shape, `(n_samples, n_classes)` or `(n_samples, n_targets)`.
- Custom objectives in Python may return gradient arrays with any strides (row-major or column-major).
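
A minimal sketch of the new contract from Python (not part of this commit; the data, parameter values, and the objective itself are illustrative assumptions): a multi-class custom objective now returns the gradient and hessian as `(n_samples, n_classes)` arrays instead of flattening them.

```python
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
n_samples, n_classes = 256, 3
X = rng.normal(size=(n_samples, 8))
y = rng.integers(0, n_classes, size=n_samples)
Xy = xgb.DMatrix(X, label=y)


def softprob_obj(predt: np.ndarray, dtrain: xgb.DMatrix):
    """Softmax cross-entropy; `predt` is the (n_samples, n_classes) margin."""
    labels = dtrain.get_label().astype(np.int64)
    e = np.exp(predt - predt.max(axis=1, keepdims=True))
    p = e / e.sum(axis=1, keepdims=True)
    grad = p.copy()
    grad[np.arange(labels.size), labels] -= 1.0
    hess = np.maximum(2.0 * p * (1.0 - p), 1e-6)
    # Both arrays keep the (n_samples, n_classes) shape; no reshape is needed.
    return grad, hess


booster = xgb.train(
    {"num_class": n_classes, "disable_default_eval_metric": True},
    Xy,
    num_boost_round=4,
    obj=softprob_obj,
)
```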
Jiaming Yuan 2023-08-24 05:29:52 +08:00 committed by GitHub
parent 6103dca0bb
commit 972730cde0
77 changed files with 1052 additions and 651 deletions

View File

@@ -154,7 +154,14 @@ xgb.iter.update <- function(booster_handle, dtrain, iter, obj) {
     pred <- predict(booster_handle, dtrain, outputmargin = TRUE, training = TRUE,
                     ntreelimit = 0)
     gpair <- obj(pred, dtrain)
-    .Call(XGBoosterBoostOneIter_R, booster_handle, dtrain, gpair$grad, gpair$hess)
+    n_samples <- dim(dtrain)[1]
+    # We still require row-major in R as I'm not quite sure how to get the stride of
+    # the matrix in C.
+    gpair$grad <- matrix(gpair$grad, nrow = n_samples, byrow = TRUE)
+    gpair$hess <- matrix(gpair$hess, nrow = n_samples, byrow = TRUE)
+    .Call(
+      XGBoosterBoostOneIter_R, booster_handle, dtrain, iter, gpair$grad, gpair$hess
+    )
   }
   return(TRUE)
 }

View File

@@ -16,7 +16,7 @@ Check these declarations against the C/Fortran source code.
 */

 /* .Call calls */
-extern SEXP XGBoosterBoostOneIter_R(SEXP, SEXP, SEXP, SEXP);
+extern SEXP XGBoosterTrainOneIter_R(SEXP, SEXP, SEXP, SEXP, SEXP);
 extern SEXP XGBoosterCreate_R(SEXP);
 extern SEXP XGBoosterCreateInEmptyObj_R(SEXP, SEXP);
 extern SEXP XGBoosterDumpModel_R(SEXP, SEXP, SEXP, SEXP);
@@ -53,7 +53,7 @@ extern SEXP XGBGetGlobalConfig_R(void);
 extern SEXP XGBoosterFeatureScore_R(SEXP, SEXP);

 static const R_CallMethodDef CallEntries[] = {
-  {"XGBoosterBoostOneIter_R", (DL_FUNC) &XGBoosterBoostOneIter_R, 4},
+  {"XGBoosterBoostOneIter_R", (DL_FUNC) &XGBoosterTrainOneIter_R, 5},
   {"XGBoosterCreate_R", (DL_FUNC) &XGBoosterCreate_R, 1},
   {"XGBoosterCreateInEmptyObj_R", (DL_FUNC) &XGBoosterCreateInEmptyObj_R, 2},
   {"XGBoosterDumpModel_R", (DL_FUNC) &XGBoosterDumpModel_R, 4},

View File

@@ -48,13 +48,6 @@
 using dmlc::BeginPtr;

-xgboost::Context const *BoosterCtx(BoosterHandle handle) {
-  CHECK_HANDLE();
-  auto *learner = static_cast<xgboost::Learner *>(handle);
-  CHECK(learner);
-  return learner->Ctx();
-}
-
 xgboost::Context const *DMatrixCtx(DMatrixHandle handle) {
   CHECK_HANDLE();
   auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
@@ -394,21 +387,25 @@ XGB_DLL SEXP XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) {
   return R_NilValue;
 }

-XGB_DLL SEXP XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess) {
+XGB_DLL SEXP XGBoosterTrainOneIter_R(SEXP handle, SEXP dtrain, SEXP iter, SEXP grad, SEXP hess) {
   R_API_BEGIN();
-  CHECK_EQ(length(grad), length(hess))
-      << "gradient and hess must have same length";
-  int len = length(grad);
-  std::vector<float> tgrad(len), thess(len);
-  auto ctx = BoosterCtx(R_ExternalPtrAddr(handle));
-  xgboost::common::ParallelFor(len, ctx->Threads(), [&](xgboost::omp_ulong j) {
-    tgrad[j] = REAL(grad)[j];
-    thess[j] = REAL(hess)[j];
-  });
-  CHECK_CALL(XGBoosterBoostOneIter(R_ExternalPtrAddr(handle),
-                                   R_ExternalPtrAddr(dtrain),
-                                   BeginPtr(tgrad), BeginPtr(thess),
-                                   len));
+  CHECK_EQ(length(grad), length(hess)) << "gradient and hess must have same length";
+  SEXP gdim = getAttrib(grad, R_DimSymbol);
+  auto n_samples = static_cast<std::size_t>(INTEGER(gdim)[0]);
+  auto n_targets = static_cast<std::size_t>(INTEGER(gdim)[1]);
+
+  SEXP hdim = getAttrib(hess, R_DimSymbol);
+  CHECK_EQ(INTEGER(hdim)[0], n_samples) << "mismatched size between gradient and hessian";
+  CHECK_EQ(INTEGER(hdim)[1], n_targets) << "mismatched size between gradient and hessian";
+
+  double const *d_grad = REAL(grad);
+  double const *d_hess = REAL(hess);
+
+  auto ctx = xgboost::detail::BoosterCtx(R_ExternalPtrAddr(handle));
+  auto [s_grad, s_hess] =
+      xgboost::detail::MakeGradientInterface(ctx, d_grad, d_hess, n_samples, n_targets);
+  CHECK_CALL(XGBoosterTrainOneIter(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(dtrain),
+                                   asInteger(iter), s_grad.c_str(), s_hess.c_str()));
   R_API_END();
   return R_NilValue;
 }
@@ -460,7 +457,7 @@ XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_con
     len *= out_shape[i];
   }
   r_out_result = PROTECT(allocVector(REALSXP, len));
-  auto ctx = BoosterCtx(R_ExternalPtrAddr(handle));
+  auto ctx = xgboost::detail::BoosterCtx(R_ExternalPtrAddr(handle));
   xgboost::common::ParallelFor(len, ctx->Threads(), [&](xgboost::omp_ulong i) {
     REAL(r_out_result)[i] = out_result[i];
   });
@@ -669,7 +666,7 @@ XGB_DLL SEXP XGBoosterFeatureScore_R(SEXP handle, SEXP json_config) {
   }
   out_scores_sexp = PROTECT(allocVector(REALSXP, len));
-  auto ctx = BoosterCtx(R_ExternalPtrAddr(handle));
+  auto ctx = xgboost::detail::BoosterCtx(R_ExternalPtrAddr(handle));
   xgboost::common::ParallelFor(len, ctx->Threads(), [&](xgboost::omp_ulong i) {
     REAL(out_scores_sexp)[i] = out_scores[i];
   });

View File

@@ -161,12 +161,13 @@ XGB_DLL SEXP XGBoosterUpdateOneIter_R(SEXP ext, SEXP iter, SEXP dtrain);
  * \brief update the model, by directly specify gradient and second order gradient,
  *        this can be used to replace UpdateOneIter, to support customized loss function
  * \param handle handle
+ * \param iter The current training iteration.
  * \param dtrain training data
  * \param grad gradient statistics
  * \param hess second order gradient statistics
  * \return R_NilValue
  */
-XGB_DLL SEXP XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess);
+XGB_DLL SEXP XGBoosterTrainOneIter_R(SEXP handle, SEXP dtrain, SEXP iter, SEXP grad, SEXP hess);

 /*!
  * \brief get evaluation statistics for xgboost

View File

@@ -76,9 +76,7 @@ def softprob_obj(predt: np.ndarray, data: xgb.DMatrix):
             grad[r, c] = g
             hess[r, c] = h

-    # Right now (XGBoost 1.0.0), reshaping is necessary
-    grad = grad.reshape((kRows * kClasses, 1))
-    hess = hess.reshape((kRows * kClasses, 1))
+    # After 2.1.0, pass the gradient as it is.
     return grad, hess

View File

@@ -68,22 +68,21 @@ def rmse_model(plot_result: bool, strategy: str) -> None:
 def custom_rmse_model(plot_result: bool, strategy: str) -> None:
     """Train using Python implementation of Squared Error."""

-    # As the experimental support status, custom objective doesn't support matrix as
-    # gradient and hessian, which will be changed in future release.
     def gradient(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
         """Compute the gradient squared error."""
         y = dtrain.get_label().reshape(predt.shape)
-        return (predt - y).reshape(y.size)
+        return predt - y

     def hessian(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
         """Compute the hessian for squared error."""
-        return np.ones(predt.shape).reshape(predt.size)
+        return np.ones(predt.shape)

     def squared_log(
         predt: np.ndarray, dtrain: xgb.DMatrix
     ) -> Tuple[np.ndarray, np.ndarray]:
         grad = gradient(predt, dtrain)
         hess = hessian(predt, dtrain)
+        # both numpy.ndarray and cupy.ndarray works.
         return grad, hess

     def rmse(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
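
For context, a self-contained sketch of how an objective written this way can drive a two-target regression (not taken from the demo; the synthetic data and parameter choices here are assumptions):

```python
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(1994)
X = rng.normal(size=(256, 8))
y = np.stack([X[:, 0] ** 2, np.abs(X[:, 1])], axis=1)  # two regression targets
Xy = xgb.DMatrix(X, label=y)


def squared_log(predt: np.ndarray, dtrain: xgb.DMatrix):
    labels = dtrain.get_label().reshape(predt.shape)
    # Gradient and hessian keep the (n_samples, n_targets) shape.
    return predt - labels, np.ones(predt.shape)


booster = xgb.train(
    {"tree_method": "hist", "multi_strategy": "one_output_per_tree"},
    Xy,
    num_boost_round=8,
    obj=squared_log,
)
```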

View File

@@ -274,8 +274,8 @@ class GradientPairInt64 {
   GradientPairInt64(GradientPairInt64 const &g) = default;
   GradientPairInt64 &operator=(GradientPairInt64 const &g) = default;

-  XGBOOST_DEVICE [[nodiscard]] T GetQuantisedGrad() const { return grad_; }
-  XGBOOST_DEVICE [[nodiscard]] T GetQuantisedHess() const { return hess_; }
+  [[nodiscard]] XGBOOST_DEVICE T GetQuantisedGrad() const { return grad_; }
+  [[nodiscard]] XGBOOST_DEVICE T GetQuantisedHess() const { return hess_; }

   XGBOOST_DEVICE GradientPairInt64 &operator+=(const GradientPairInt64 &rhs) {
     grad_ += rhs.grad_;

View File

@@ -789,16 +789,14 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
  * \param out The address to hold number of rows.
  * \return 0 when success, -1 when failure happens
  */
-XGB_DLL int XGDMatrixNumRow(DMatrixHandle handle,
-                            bst_ulong *out);
+XGB_DLL int XGDMatrixNumRow(DMatrixHandle handle, bst_ulong *out);

 /*!
  * \brief get number of columns
  * \param handle the handle to the DMatrix
  * \param out The output of number of columns
  * \return 0 when success, -1 when failure happens
  */
-XGB_DLL int XGDMatrixNumCol(DMatrixHandle handle,
-                            bst_ulong *out);
+XGB_DLL int XGDMatrixNumCol(DMatrixHandle handle, bst_ulong *out);

 /*!
  * \brief Get number of valid values from DMatrix.
@@ -945,21 +943,30 @@ XGB_DLL int XGBoosterUpdateOneIter(BoosterHandle handle, int iter, DMatrixHandle
  * @example c-api-demo.c
  */

-/*!
- * \brief update the model, by directly specify gradient and second order gradient,
- *        this can be used to replace UpdateOneIter, to support customized loss function
- * \param handle handle
- * \param dtrain training data
- * \param grad gradient statistics
- * \param hess second order gradient statistics
- * \param len length of grad/hess array
- * \return 0 when success, -1 when failure happens
+/**
+ * @deprecated since 2.1.0
  */
-XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle,
-                                  DMatrixHandle dtrain,
-                                  float *grad,
-                                  float *hess,
-                                  bst_ulong len);
+XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle, DMatrixHandle dtrain, float *grad,
+                                  float *hess, bst_ulong len);
+
+/**
+ * @brief Update a model with gradient and Hessian. This is used for training with a
+ *        custom objective function.
+ *
+ * @since 2.0.0
+ *
+ * @param handle handle
+ * @param dtrain The training data.
+ * @param iter   The current iteration round. When training continuation is used, the count
+ *               should restart.
+ * @param grad   Json encoded __(cuda)_array_interface__ for gradient.
+ * @param hess   Json encoded __(cuda)_array_interface__ for Hessian.
+ *
+ * @return 0 when success, -1 when failure happens
+ */
+XGB_DLL int XGBoosterTrainOneIter(BoosterHandle handle, DMatrixHandle dtrain, int iter,
+                                  char const *grad, char const *hess);

 /*!
  * \brief get evaluation statistics for xgboost
  * \param handle handle
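
For reference, and as an illustration rather than code from this commit: the `grad`/`hess` strings are the JSON you get by serializing NumPy's `__array_interface__` (or CuPy's `__cuda_array_interface__`) for a `(n_samples, n_targets)` buffer, which is how the language bindings in this commit build them.

```python
import json

import numpy as np

grad = np.zeros((100, 3), dtype=np.float32)
# The dict carries the data pointer, shape, and dtype; XGBoost parses the JSON
# with its ArrayInterface handler on the C++ side.
print(json.dumps(grad.__array_interface__))
# Roughly: {"data": [..., false], "strides": null, "descr": [["", "<f4"]],
#           "typestr": "<f4", "shape": [100, 3], "version": 3}
```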

View File

@@ -70,22 +70,25 @@ class GradientBooster : public Model, public Configurable {
                      GradientBooster* /*out*/, bool* /*out_of_bound*/) const {
     LOG(FATAL) << "Slice is not supported by the current booster.";
   }
-  /*! \brief Return number of boosted rounds.
+  /**
+   * @brief Return number of boosted rounds.
    */
-  virtual int32_t BoostedRounds() const = 0;
+  [[nodiscard]] virtual std::int32_t BoostedRounds() const = 0;
   /**
    * \brief Whether the model has already been trained. When tree booster is chosen, then
    *        returns true when there are existing trees.
    */
-  virtual bool ModelFitted() const = 0;
-  /*!
-   * \brief perform update to the model(boosting)
-   * \param p_fmat feature matrix that provide access to features
-   * \param in_gpair address of the gradient pair statistics of the data
-   * \param prediction The output prediction cache entry that needs to be updated.
+  [[nodiscard]] virtual bool ModelFitted() const = 0;
+  /**
+   * @brief perform update to the model(boosting)
+   *
+   * @param p_fmat feature matrix that provide access to features
+   * @param in_gpair address of the gradient pair statistics of the data
+   * @param prediction The output prediction cache entry that needs to be updated.
    *        the booster may change content of gpair
+   * @param obj The objective function used for boosting.
    */
-  virtual void DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
+  virtual void DoBoost(DMatrix* p_fmat, linalg::Matrix<GradientPair>* in_gpair,
                        PredictionCacheEntry*, ObjFunction const* obj) = 0;

   /**
@@ -165,18 +168,17 @@ class GradientBooster : public Model, public Configurable {
    * \param format the format to dump the model in
    * \return a vector of dump for boosters.
    */
-  virtual std::vector<std::string> DumpModel(const FeatureMap& fmap,
-                                             bool with_stats,
-                                             std::string format) const = 0;
+  [[nodiscard]] virtual std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
+                                                           std::string format) const = 0;

   virtual void FeatureScore(std::string const& importance_type,
                             common::Span<int32_t const> trees,
                             std::vector<bst_feature_t>* features,
                             std::vector<float>* scores) const = 0;
-  /*!
-   * \brief Whether the current booster uses GPU.
+  /**
+   * @brief Whether the current booster uses GPU.
    */
-  virtual bool UseGPU() const = 0;
+  [[nodiscard]] virtual bool UseGPU() const = 0;
   /*!
    * \brief create a gradient booster from given name
    * \param name name of gradient booster

View File

@@ -76,17 +76,18 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
    * \param iter current iteration number
    * \param train reference to the data matrix.
    */
-  virtual void UpdateOneIter(int iter, std::shared_ptr<DMatrix> train) = 0;
-  /*!
-   * \brief Do customized gradient boosting with in_gpair.
-   *        in_gair can be mutated after this call.
-   * \param iter current iteration number
-   * \param train reference to the data matrix.
-   * \param in_gpair The input gradient statistics.
+  virtual void UpdateOneIter(std::int32_t iter, std::shared_ptr<DMatrix> train) = 0;
+  /**
+   * @brief Do customized gradient boosting with in_gpair.
+   *
+   * @note in_gpair can be mutated after this call.
+   *
+   * @param iter current iteration number
+   * @param train reference to the data matrix.
+   * @param in_gpair The input gradient statistics.
    */
-  virtual void BoostOneIter(int iter,
-                            std::shared_ptr<DMatrix> train,
-                            HostDeviceVector<GradientPair>* in_gpair) = 0;
+  virtual void BoostOneIter(std::int32_t iter, std::shared_ptr<DMatrix> train,
+                            linalg::Matrix<GradientPair>* in_gpair) = 0;
   /*!
    * \brief evaluate the model for specific iteration using the configured metrics.
    * \param iter iteration number

View File

@@ -292,7 +292,7 @@ enum Order : std::uint8_t {
 template <typename T, int32_t kDim>
 class TensorView {
  public:
-  using ShapeT = size_t[kDim];
+  using ShapeT = std::size_t[kDim];
   using StrideT = ShapeT;

  private:
@@ -400,10 +400,14 @@ class TensorView {
    * \param shape  shape of the tensor
    * \param device Device ordinal
    */
-  template <typename I, int32_t D>
+  template <typename I, std::int32_t D>
   LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], std::int32_t device)
       : TensorView{data, shape, device, Order::kC} {}
+  template <typename I, std::int32_t D>
+  LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], DeviceOrd device)
+      : TensorView{data, shape, device.ordinal, Order::kC} {}
+
   template <typename I, int32_t D>
   LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], std::int32_t device, Order order)
       : data_{data}, ptr_{data_.data()}, device_{device} {
@@ -446,6 +450,10 @@ class TensorView {
     });
     this->CalcSize();
   }
+  template <typename I, std::int32_t D>
+  LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], I const (&stride)[D],
+                       DeviceOrd device)
+      : TensorView{data, shape, stride, device.ordinal} {}

   template <
       typename U,
@@ -741,7 +749,7 @@ auto ArrayInterfaceStr(TensorView<T, D> const &t) {
 template <typename T, int32_t kDim = 5>
 class Tensor {
  public:
-  using ShapeT = size_t[kDim];
+  using ShapeT = std::size_t[kDim];
   using StrideT = ShapeT;

  private:
@@ -775,6 +783,9 @@ class Tensor {
   template <typename I, int32_t D>
   explicit Tensor(I const (&shape)[D], std::int32_t device, Order order = kC)
       : Tensor{common::Span<I const, D>{shape}, device, order} {}
+  template <typename I, int32_t D>
+  explicit Tensor(I const (&shape)[D], DeviceOrd device, Order order = kC)
+      : Tensor{common::Span<I const, D>{shape}, device.ordinal, order} {}

   template <typename I, size_t D>
   explicit Tensor(common::Span<I const, D> shape, std::int32_t device, Order order = kC)
@@ -814,6 +825,10 @@ class Tensor {
     // shape
     this->Initialize(shape, device);
   }
+  template <typename I, int32_t D>
+  explicit Tensor(std::initializer_list<T> data, I const (&shape)[D], DeviceOrd device,
+                  Order order = kC)
+      : Tensor{data, shape, device.ordinal, order} {}

   /**
    * \brief Index operator. Not thread safe, should not be used in performance critical
    *        region. For more efficient indexing, consider getting a view first.
@@ -832,9 +847,9 @@ class Tensor {
   }

   /**
-   * \brief Get a \ref TensorView for this tensor.
+   * @brief Get a @ref TensorView for this tensor.
    */
-  TensorView<T, kDim> View(int32_t device) {
+  TensorView<T, kDim> View(std::int32_t device) {
     if (device >= 0) {
       data_.SetDevice(device);
       auto span = data_.DeviceSpan();
@@ -844,7 +859,7 @@ class Tensor {
       return {span, shape_, device, order_};
     }
   }
-  TensorView<T const, kDim> View(int32_t device) const {
+  TensorView<T const, kDim> View(std::int32_t device) const {
     if (device >= 0) {
       data_.SetDevice(device);
       auto span = data_.ConstDeviceSpan();
@@ -854,6 +869,26 @@ class Tensor {
       return {span, shape_, device, order_};
     }
   }
+  auto View(DeviceOrd device) {
+    if (device.IsCUDA()) {
+      data_.SetDevice(device);
+      auto span = data_.DeviceSpan();
+      return TensorView<T, kDim>{span, shape_, device.ordinal, order_};
+    } else {
+      auto span = data_.HostSpan();
+      return TensorView<T, kDim>{span, shape_, device.ordinal, order_};
+    }
+  }
+  auto View(DeviceOrd device) const {
+    if (device.IsCUDA()) {
+      data_.SetDevice(device);
+      auto span = data_.ConstDeviceSpan();
+      return TensorView<T const, kDim>{span, shape_, device.ordinal, order_};
+    } else {
+      auto span = data_.ConstHostSpan();
+      return TensorView<T const, kDim>{span, shape_, device.ordinal, order_};
+    }
+  }

   auto HostView() const { return this->View(-1); }
   auto HostView() { return this->View(-1); }
@@ -931,6 +966,7 @@ class Tensor {
    * \brief Set device ordinal for this tensor.
    */
   void SetDevice(int32_t device) const { data_.SetDevice(device); }
+  void SetDevice(DeviceOrd device) const { data_.SetDevice(device); }
   [[nodiscard]] int32_t DeviceIdx() const { return data_.DeviceIdx(); }
 };

View File

@@ -49,9 +49,8 @@ class LinearUpdater : public Configurable {
    * \param model Model to be updated.
    * \param sum_instance_weight The sum instance weights, used to normalise l1/l2 penalty.
    */
-  virtual void Update(HostDeviceVector<GradientPair>* in_gpair, DMatrix* data,
-                      gbm::GBLinearModel* model,
-                      double sum_instance_weight) = 0;
+  virtual void Update(linalg::Matrix<GradientPair>* in_gpair, DMatrix* data,
+                      gbm::GBLinearModel* model, double sum_instance_weight) = 0;

   /*!
    * \brief Create a linear updater given name

View File

@@ -41,17 +41,16 @@ class ObjFunction : public Configurable {
    * \param args arguments to the objective function.
    */
   virtual void Configure(const std::vector<std::pair<std::string, std::string> >& args) = 0;
-  /*!
-   * \brief Get gradient over each of predictions, given existing information.
-   * \param preds prediction of current round
-   * \param info information about labels, weights, groups in rank
-   * \param iteration current iteration number.
-   * \param out_gpair output of get gradient, saves gradient and second order gradient in
+  /**
+   * @brief Get gradient over each of predictions, given existing information.
+   *
+   * @param preds prediction of current round
+   * @param info information about labels, weights, groups in rank
+   * @param iteration current iteration number.
+   * @param out_gpair output of get gradient, saves gradient and second order gradient in
    */
-  virtual void GetGradient(const HostDeviceVector<bst_float>& preds,
-                           const MetaInfo& info,
-                           int iteration,
-                           HostDeviceVector<GradientPair>* out_gpair) = 0;
+  virtual void GetGradient(const HostDeviceVector<bst_float>& preds, const MetaInfo& info,
+                           std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) = 0;

   /*! \return the default evaluation metric for the objective */
   virtual const char* DefaultEvalMetric() const = 0;
@@ -81,9 +80,7 @@ class ObjFunction : public Configurable {
    *        used by gradient boosting
    * \return transformed value
    */
-  virtual bst_float ProbToMargin(bst_float base_score) const {
-    return base_score;
-  }
+  [[nodiscard]] virtual bst_float ProbToMargin(bst_float base_score) const { return base_score; }
   /**
    * \brief Make initialize estimation of prediction.
    *
@@ -94,14 +91,14 @@ class ObjFunction : public Configurable {
   /*!
    * \brief Return task of this objective.
    */
-  virtual struct ObjInfo Task() const = 0;
+  [[nodiscard]] virtual struct ObjInfo Task() const = 0;
   /**
-   * \brief Return number of targets for input matrix. Right now XGBoost supports only
+   * @brief Return number of targets for input matrix. Right now XGBoost supports only
    *        multi-target regression.
    */
-  virtual bst_target_t Targets(MetaInfo const& info) const {
+  [[nodiscard]] virtual bst_target_t Targets(MetaInfo const& info) const {
     if (info.labels.Shape(1) > 1) {
-      LOG(FATAL) << "multioutput is not supported by current objective function";
+      LOG(FATAL) << "multioutput is not supported by the current objective function";
     }
     return 1;
   }

View File

@@ -71,7 +71,7 @@ class TreeUpdater : public Configurable {
    *        but maybe different random seeds, usually one tree is passed in at a time,
    *        there can be multiple trees when we train random forest style model
    */
-  virtual void Update(tree::TrainParam const* param, HostDeviceVector<GradientPair>* gpair,
+  virtual void Update(tree::TrainParam const* param, linalg::Matrix<GradientPair>* gpair,
                       DMatrix* data, common::Span<HostDeviceVector<bst_node_t>> out_position,
                       const std::vector<RegTree*>& out_trees) = 0;

View File

@@ -218,34 +218,48 @@ public class Booster implements Serializable, KryoSerializable {
     XGBoostJNI.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
   }

+  @Deprecated
+  public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
+    float[][] predicts = this.predict(dtrain, true, 0, false, false);
+    List<float[]> gradients = obj.getGradient(predicts, dtrain);
+    this.boost(dtrain, gradients.get(0), gradients.get(1));
+  }
+
   /**
    * Update with customize obj func
    *
    * @param dtrain training data
+   * @param iter   The current training iteration.
    * @param obj    customized objective class
    * @throws XGBoostError native error
    */
-  public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
+  public void update(DMatrix dtrain, int iter, IObjective obj) throws XGBoostError {
     float[][] predicts = this.predict(dtrain, true, 0, false, false);
     List<float[]> gradients = obj.getGradient(predicts, dtrain);
-    boost(dtrain, gradients.get(0), gradients.get(1));
+    this.boost(dtrain, iter, gradients.get(0), gradients.get(1));
+  }
+
+  @Deprecated
+  public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
+    this.boost(dtrain, 0, grad, hess);
   }

   /**
-   * update with give grad and hess
+   * Update with give grad and hess
    *
    * @param dtrain training data
+   * @param iter   The current training iteration.
    * @param grad   first order of gradient
    * @param hess   seconde order of gradient
    * @throws XGBoostError native error
    */
-  public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
+  public void boost(DMatrix dtrain, int iter, float[] grad, float[] hess) throws XGBoostError {
     if (grad.length != hess.length) {
       throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
               hess.length));
     }
-    XGBoostJNI.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle,
-            dtrain.getHandle(), grad, hess));
+    XGBoostJNI.checkCall(XGBoostJNI.XGBoosterTrainOneIter(handle,
+            dtrain.getHandle(), iter, grad, hess));
   }

   /**

View File

@@ -110,7 +110,7 @@ class XGBoostJNI {
   public final static native int XGBoosterUpdateOneIter(long handle, int iter, long dtrain);

-  public final static native int XGBoosterBoostOneIter(long handle, long dtrain, float[] grad,
+  public final static native int XGBoosterTrainOneIter(long handle, long dtrain, int iter, float[] grad,
                                                        float[] hess);

   public final static native int XGBoosterEvalOneIter(long handle, int iter, long[] dmats,

View File

@@ -106,27 +106,41 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
     booster.update(dtrain.jDMatrix, iter)
   }

+  @throws(classOf[XGBoostError])
+  @deprecated
+  def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
+    booster.update(dtrain.jDMatrix, obj)
+  }
+
   /**
    * update with customize obj func
    *
    * @param dtrain training data
+   * @param iter The current training iteration
    * @param obj customized objective class
    */
   @throws(classOf[XGBoostError])
-  def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
-    booster.update(dtrain.jDMatrix, obj)
+  def update(dtrain: DMatrix, iter: Int, obj: ObjectiveTrait): Unit = {
+    booster.update(dtrain.jDMatrix, iter, obj)
+  }
+
+  @throws(classOf[XGBoostError])
+  @deprecated
+  def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
+    booster.boost(dtrain.jDMatrix, grad, hess)
   }

   /**
    * update with give grad and hess
    *
    * @param dtrain training data
+   * @param iter The current training iteration
    * @param grad first order of gradient
    * @param hess seconde order of gradient
    */
   @throws(classOf[XGBoostError])
-  def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
-    booster.boost(dtrain.jDMatrix, grad, hess)
+  def boost(dtrain: DMatrix, iter: Int, grad: Array[Float], hess: Array[Float]): Unit = {
+    booster.boost(dtrain.jDMatrix, iter, grad, hess)
   }

   /**

View File

@@ -28,6 +28,7 @@
 #include <type_traits>
 #include <vector>

+#include "../../../src/c_api/c_api_error.h"
 #include "../../../src/c_api/c_api_utils.h"

 #define JVM_CHECK_CALL(__expr) \
@@ -579,22 +580,44 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterUpdateOne
 /*
  * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
- * Method:    XGBoosterBoostOneIter
- * Signature: (JJ[F[F)V
+ * Method:    XGBoosterTrainOneIter
+ * Signature: (JJI[F[F)I
  */
-JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterBoostOneIter
-  (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) {
-  BoosterHandle handle = (BoosterHandle) jhandle;
-  DMatrixHandle dtrain = (DMatrixHandle) jdtrain;
-  jfloat* grad = jenv->GetFloatArrayElements(jgrad, 0);
-  jfloat* hess = jenv->GetFloatArrayElements(jhess, 0);
-  bst_ulong len = (bst_ulong)jenv->GetArrayLength(jgrad);
-  int ret = XGBoosterBoostOneIter(handle, dtrain, grad, hess, len);
-  JVM_CHECK_CALL(ret);
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterTrainOneIter(
+    JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jint jiter, jfloatArray jgrad,
+    jfloatArray jhess) {
+  API_BEGIN();
+  BoosterHandle handle = reinterpret_cast<BoosterHandle *>(jhandle);
+  DMatrixHandle dtrain = reinterpret_cast<DMatrixHandle *>(jdtrain);
+  CHECK(handle);
+  CHECK(dtrain);
+
+  bst_ulong n_samples{0};
+  JVM_CHECK_CALL(XGDMatrixNumRow(dtrain, &n_samples));
+
+  bst_ulong len = static_cast<bst_ulong>(jenv->GetArrayLength(jgrad));
+  jfloat *grad = jenv->GetFloatArrayElements(jgrad, nullptr);
+  jfloat *hess = jenv->GetFloatArrayElements(jhess, nullptr);
+  CHECK(grad);
+  CHECK(hess);
+
+  xgboost::bst_target_t n_targets{1};
+  if (len != n_samples && n_samples != 0) {
+    CHECK_EQ(len % n_samples, 0) << "Invalid size of gradient.";
+    n_targets = len / n_samples;
+  }
+
+  auto ctx = xgboost::detail::BoosterCtx(handle);
+  auto [s_grad, s_hess] =
+      xgboost::detail::MakeGradientInterface(ctx, grad, hess, n_samples, n_targets);
+  int ret = XGBoosterTrainOneIter(handle, dtrain, static_cast<std::int32_t>(jiter), s_grad.c_str(),
+                                  s_hess.c_str());
+
   // release
   jenv->ReleaseFloatArrayElements(jgrad, grad, 0);
   jenv->ReleaseFloatArrayElements(jhess, hess, 0);
   return ret;
+  API_END();
 }
/* /*

View File

@@ -185,11 +185,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterUpdateOne
 /*
  * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
- * Method:    XGBoosterBoostOneIter
- * Signature: (JJ[F[F)I
+ * Method:    XGBoosterTrainOneIter
+ * Signature: (JJI[F[F)I
  */
-JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterBoostOneIter
-  (JNIEnv *, jclass, jlong, jlong, jfloatArray, jfloatArray);
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterTrainOneIter
+  (JNIEnv *, jclass, jlong, jlong, jint, jfloatArray, jfloatArray);

 /*
  * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
@@ -386,19 +386,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
 /*
  * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
  * Method:    XGBoosterSetStrFeatureInfo
- * Signature: (JLjava/lang/String;[Ljava/lang/String;])I
+ * Signature: (JLjava/lang/String;[Ljava/lang/String;)I
  */
-JNIEXPORT jint JNICALL
-  Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo
   (JNIEnv *, jclass, jlong, jstring, jobjectArray);

 /*
  * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
  * Method:    XGBoosterGetStrFeatureInfo
- * Signature: (JLjava/lang/String;[Ljava/lang/String;])I
+ * Signature: (JLjava/lang/String;[Ljava/lang/String;)I
  */
-JNIEXPORT jint JNICALL
-  Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo
   (JNIEnv *, jclass, jlong, jstring, jobjectArray);

 #ifdef __cplusplus

View File

@@ -1,5 +1,5 @@
-/*!
- * Copyright 2015-2022 by Contributors
+/**
+ * Copyright 2015-2023, XGBoost Contributors
  * \file custom_metric.cc
  * \brief This is an example to define plugin of xgboost.
  *   This plugin defines the additional metric function.
@@ -9,9 +9,7 @@
 #include <xgboost/objective.h>
 #include <xgboost/json.h>

-namespace xgboost {
-namespace obj {
+namespace xgboost::obj {
 // This is a helpful data structure to define parameters
 // You do not have to use it.
 // see http://dmlc-core.readthedocs.org/en/latest/parameter.html
@@ -33,38 +31,38 @@ class MyLogistic : public ObjFunction {
  public:
   void Configure(const Args& args) override { param_.UpdateAllowUnknown(args); }
-  ObjInfo Task() const override { return ObjInfo::kRegression; }
-  void GetGradient(const HostDeviceVector<bst_float>& preds, const MetaInfo& info, int32_t /*iter*/,
-                   HostDeviceVector<GradientPair>* out_gpair) override {
-    out_gpair->Resize(preds.Size());
-    const std::vector<bst_float>& preds_h = preds.HostVector();
-    std::vector<GradientPair>& out_gpair_h = out_gpair->HostVector();
+  [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; }
+  void GetGradient(const HostDeviceVector<float>& preds, MetaInfo const& info,
+                   std::int32_t /*iter*/, linalg::Matrix<GradientPair>* out_gpair) override {
+    out_gpair->Reshape(info.num_row_, 1);
+    const std::vector<float>& preds_h = preds.HostVector();
+    auto out_gpair_h = out_gpair->HostView();
     auto const labels_h = info.labels.HostView();
     for (size_t i = 0; i < preds_h.size(); ++i) {
-      bst_float w = info.GetWeight(i);
+      float w = info.GetWeight(i);
       // scale the negative examples!
       if (labels_h(i) == 0.0f) w *= param_.scale_neg_weight;
       // logistic transformation
-      bst_float p = 1.0f / (1.0f + std::exp(-preds_h[i]));
+      float p = 1.0f / (1.0f + std::exp(-preds_h[i]));
       // this is the gradient
-      bst_float grad = (p - labels_h(i)) * w;
+      float grad = (p - labels_h(i)) * w;
       // this is the second order gradient
-      bst_float hess = p * (1.0f - p) * w;
+      float hess = p * (1.0f - p) * w;
-      out_gpair_h.at(i) = GradientPair(grad, hess);
+      out_gpair_h(i) = GradientPair(grad, hess);
     }
   }
-  const char* DefaultEvalMetric() const override {
+  [[nodiscard]] const char* DefaultEvalMetric() const override {
     return "logloss";
   }
-  void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
+  void PredTransform(HostDeviceVector<float> *io_preds) const override {
     // transform margin value to probability.
-    std::vector<bst_float> &preds = io_preds->HostVector();
+    std::vector<float> &preds = io_preds->HostVector();
     for (auto& pred : preds) {
       pred = 1.0f / (1.0f + std::exp(-pred));
     }
   }
-  bst_float ProbToMargin(bst_float base_score) const override {
+  [[nodiscard]] float ProbToMargin(float base_score) const override {
     // transform probability to margin value
     return -std::log(1.0f / base_score - 1.0f);
   }
@@ -89,5 +87,4 @@ XGBOOST_REGISTER_OBJECTIVE(MyLogistic, "mylogistic")
 .describe("User defined logistic regression plugin")
 .set_body([]() { return new MyLogistic(); });
-}  // namespace obj
-}  // namespace xgboost
+}  // namespace xgboost::obj

View File

@@ -2053,12 +2053,14 @@ class Booster:
         else:
             pred = self.predict(dtrain, output_margin=True, training=True)
             grad, hess = fobj(pred, dtrain)
-            self.boost(dtrain, grad, hess)
+            self.boost(dtrain, iteration=iteration, grad=grad, hess=hess)

-    def boost(self, dtrain: DMatrix, grad: np.ndarray, hess: np.ndarray) -> None:
-        """Boost the booster for one iteration, with customized gradient
-        statistics.  Like :py:func:`xgboost.Booster.update`, this
-        function should not be called directly by users.
+    def boost(
+        self, dtrain: DMatrix, iteration: int, grad: NumpyOrCupy, hess: NumpyOrCupy
+    ) -> None:
+        """Boost the booster for one iteration with customized gradient statistics.
+        Like :py:func:`xgboost.Booster.update`, this function should not be called
+        directly by users.

         Parameters
         ----------
@@ -2070,19 +2072,53 @@ class Booster:
             The second order of gradient.

         """
-        if len(grad) != len(hess):
-            raise ValueError(f"grad / hess length mismatch: {len(grad)} / {len(hess)}")
-        if not isinstance(dtrain, DMatrix):
-            raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
+        from .data import (
+            _array_interface,
+            _cuda_array_interface,
+            _ensure_np_dtype,
+            _is_cupy_array,
+        )
+
         self._assign_dmatrix_features(dtrain)

+        def is_flatten(array: NumpyOrCupy) -> bool:
+            return len(array.shape) == 1 or array.shape[1] == 1
+
+        def array_interface(array: NumpyOrCupy) -> bytes:
+            # Can we check for __array_interface__ instead of a specific type instead?
+            msg = (
+                "Expecting `np.ndarray` or `cupy.ndarray` for gradient and hessian."
+                f" Got: {type(array)}"
+            )
+            if not isinstance(array, np.ndarray) and not _is_cupy_array(array):
+                raise TypeError(msg)
+
+            n_samples = dtrain.num_row()
+            if array.shape[0] != n_samples and is_flatten(array):
+                warnings.warn(
+                    "Since 2.1.0, the shape of the gradient and hessian is required to"
+                    " be (n_samples, n_targets) or (n_samples, n_classes).",
+                    FutureWarning,
+                )
+                array = array.reshape(n_samples, array.size // n_samples)
+
+            if isinstance(array, np.ndarray):
+                array, _ = _ensure_np_dtype(array, array.dtype)
+                interface = _array_interface(array)
+            elif _is_cupy_array(array):
+                interface = _cuda_array_interface(array)
+            else:
+                raise TypeError(msg)
+
+            return interface
+
         _check_call(
-            _LIB.XGBoosterBoostOneIter(
+            _LIB.XGBoosterTrainOneIter(
                 self.handle,
                 dtrain.handle,
-                c_array(ctypes.c_float, grad),
-                c_array(ctypes.c_float, hess),
-                c_bst_ulong(len(grad)),
+                iteration,
+                array_interface(grad),
+                array_interface(hess),
             )
        )
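
A hedged, self-contained sketch of this lower-level path (synthetic data only; in normal use `train()`/`update()` call `boost()` for you):

```python
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X = rng.normal(size=(64, 4))
y = X[:, 0]
Xy = xgb.DMatrix(X, label=y)
bst = xgb.Booster({"tree_method": "hist"}, [Xy])

predt = bst.predict(Xy, output_margin=True, training=True)
grad = (predt - y).reshape(-1, 1)  # (n_samples, 1)
hess = np.ones_like(grad)
# Flattened multi-target gradients now trigger a FutureWarning; pass
# (n_samples, n_targets) arrays instead.
bst.boost(Xy, iteration=0, grad=grad, hess=hess)
```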

View File

@@ -763,13 +763,31 @@ def softmax(x: np.ndarray) -> np.ndarray:
     return e / np.sum(e)


-def softprob_obj(classes: int) -> SklObjective:
+def softprob_obj(
+    classes: int, use_cupy: bool = False, order: str = "C", gdtype: str = "float32"
+) -> SklObjective:
+    """Custom softprob objective for testing.
+
+    Parameters
+    ----------
+    use_cupy :
+        Whether the objective should return cupy arrays.
+    order :
+        The order of gradient matrices. "C" or "F".
+    gdtype :
+        DType for gradient. Hessian is not set. This is for testing asymmetric types.
+    """
+    if use_cupy:
+        import cupy as backend
+    else:
+        backend = np
+
     def objective(
-        labels: np.ndarray, predt: np.ndarray
-    ) -> Tuple[np.ndarray, np.ndarray]:
+        labels: backend.ndarray, predt: backend.ndarray
+    ) -> Tuple[backend.ndarray, backend.ndarray]:
         rows = labels.shape[0]
-        grad = np.zeros((rows, classes), dtype=float)
-        hess = np.zeros((rows, classes), dtype=float)
+        grad = backend.zeros((rows, classes), dtype=np.float32)
+        hess = backend.zeros((rows, classes), dtype=np.float32)
         eps = 1e-6
         for r in range(predt.shape[0]):
             target = labels[r]
@@ -781,8 +799,10 @@ def softprob_obj(classes: int) -> SklObjective:
             grad[r, c] = g
             hess[r, c] = h

-        grad = grad.reshape((rows * classes, 1))
-        hess = hess.reshape((rows * classes, 1))
+        grad = grad.reshape((rows, classes))
+        hess = hess.reshape((rows, classes))
+        grad = backend.require(grad, requirements=order, dtype=gdtype)
+        hess = backend.require(hess, requirements=order)
         return grad, hess

     return objective
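
A small usage sketch of the extended helper (assuming it remains importable from `xgboost.testing`, and with made-up inputs): column-major and non-float32 gradients should be accepted unchanged by the booster.

```python
import numpy as np
from xgboost.testing import softprob_obj

labels = np.array([0, 1, 2, 1], dtype=np.float32)
predt = np.zeros((4, 3), dtype=np.float32)

obj = softprob_obj(3, order="F", gdtype="float64")
grad, hess = obj(labels, predt)
# The helper returns (n_samples, n_classes) matrices in the requested layout.
assert grad.shape == (4, 3) and grad.flags["F_CONTIGUOUS"]
assert grad.dtype == np.float64 and hess.dtype == np.float32
```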

View File

@@ -178,7 +178,7 @@ def train(
     for i in range(start_iteration, num_boost_round):
         if cb_container.before_iteration(bst, i, dtrain, evals):
             break
-        bst.update(dtrain, i, obj)
+        bst.update(dtrain, iteration=i, fobj=obj)
         if cb_container.after_iteration(bst, i, dtrain, evals):
             break

View File

@@ -22,6 +22,7 @@
 #include "../common/charconv.h"          // for from_chars, to_chars, NumericLimits, from_ch...
 #include "../common/hist_util.h"         // for HistogramCuts
 #include "../common/io.h"                // for FileExtension, LoadSequentialFile, MemoryBuf...
+#include "../common/linalg_op.h"         // for ElementWiseTransformHost
 #include "../common/threading_utils.h"   // for OmpGetNumThreads, ParallelFor
 #include "../data/adapter.h"             // for ArrayAdapter, DenseAdapter, RecordBatchesIte...
 #include "../data/ellpack_page.h"        // for EllpackPage
@@ -68,6 +69,7 @@ XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) {
   }
 }

+static_assert(DMLC_CXX11_THREAD_LOCAL, "XGBoost depends on thread-local storage.");
 using GlobalConfigAPIThreadLocalStore = dmlc::ThreadLocalStore<XGBAPIThreadLocalEntry>;

 #if !defined(XGBOOST_USE_CUDA)
@@ -717,8 +719,7 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
   API_END();
 }

-XGB_DLL int XGDMatrixNumRow(const DMatrixHandle handle,
-                            xgboost::bst_ulong *out) {
+XGB_DLL int XGDMatrixNumRow(DMatrixHandle handle, xgboost::bst_ulong *out) {
   API_BEGIN();
   CHECK_HANDLE();
   auto p_m = CastDMatrixHandle(handle);
@@ -727,8 +728,7 @@ XGB_DLL int XGDMatrixNumRow(const DMatrixHandle handle,
   API_END();
 }

-XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle,
-                            xgboost::bst_ulong *out) {
+XGB_DLL int XGDMatrixNumCol(DMatrixHandle handle, xgboost::bst_ulong *out) {
   API_BEGIN();
   CHECK_HANDLE();
   auto p_m = CastDMatrixHandle(handle);
@@ -970,28 +970,71 @@ XGB_DLL int XGBoosterUpdateOneIter(BoosterHandle handle,
   API_END();
 }

-XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle,
-                                  DMatrixHandle dtrain,
-                                  bst_float *grad,
-                                  bst_float *hess,
-                                  xgboost::bst_ulong len) {
+XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle, DMatrixHandle dtrain, bst_float *grad,
+                                  bst_float *hess, xgboost::bst_ulong len) {
   API_BEGIN();
   CHECK_HANDLE();
-  HostDeviceVector<GradientPair> tmp_gpair;
-  auto* bst = static_cast<Learner*>(handle);
-  auto* dtr =
-      static_cast<std::shared_ptr<DMatrix>*>(dtrain);
-  tmp_gpair.Resize(len);
-  std::vector<GradientPair>& tmp_gpair_h = tmp_gpair.HostVector();
-  if (len > 0) {
-    xgboost_CHECK_C_ARG_PTR(grad);
-    xgboost_CHECK_C_ARG_PTR(hess);
-  }
-  for (xgboost::bst_ulong i = 0; i < len; ++i) {
-    tmp_gpair_h[i] = GradientPair(grad[i], hess[i]);
-  }
-  bst->BoostOneIter(0, *dtr, &tmp_gpair);
+  error::DeprecatedFunc(__func__, "2.1.0", "XGBoosterTrainOneIter");
+  auto *learner = static_cast<Learner *>(handle);
+  auto ctx = learner->Ctx()->MakeCPU();
+
+  auto t_grad = linalg::MakeTensorView(&ctx, common::Span{grad, len}, len);
+  auto t_hess = linalg::MakeTensorView(&ctx, common::Span{hess, len}, len);
+
+  auto s_grad = linalg::ArrayInterfaceStr(t_grad);
+  auto s_hess = linalg::ArrayInterfaceStr(t_hess);
+
+  return XGBoosterTrainOneIter(handle, dtrain, 0, s_grad.c_str(), s_hess.c_str());
+  API_END();
+}
+
+namespace xgboost {
+// copy user-supplied CUDA gradient arrays
+void CopyGradientFromCUDAArrays(Context const *, ArrayInterface<2, false> const &,
+                                ArrayInterface<2, false> const &, linalg::Matrix<GradientPair> *)
+#if !defined(XGBOOST_USE_CUDA)
+{
+  common::AssertGPUSupport();
+}
+#else
+    ;  // NOLINT
+#endif
+}  // namespace xgboost
+
+XGB_DLL int XGBoosterTrainOneIter(BoosterHandle handle, DMatrixHandle dtrain, int iter,
+                                  char const *grad, char const *hess) {
+  API_BEGIN();
+  CHECK_HANDLE();
+  xgboost_CHECK_C_ARG_PTR(grad);
+  xgboost_CHECK_C_ARG_PTR(hess);
+  auto p_fmat = CastDMatrixHandle(dtrain);
+  ArrayInterface<2, false> i_grad{StringView{grad}};
+  ArrayInterface<2, false> i_hess{StringView{hess}};
+  StringView msg{"Mismatched shape between the gradient and hessian."};
+  CHECK_EQ(i_grad.Shape(0), i_hess.Shape(0)) << msg;
+  CHECK_EQ(i_grad.Shape(1), i_hess.Shape(1)) << msg;
+  linalg::Matrix<GradientPair> gpair;
+  auto grad_is_cuda = ArrayInterfaceHandler::IsCudaPtr(i_grad.data);
+  auto hess_is_cuda = ArrayInterfaceHandler::IsCudaPtr(i_hess.data);
+  CHECK_EQ(i_grad.Shape(0), p_fmat->Info().num_row_)
+      << "Mismatched size between the gradient and training data.";
+  CHECK_EQ(grad_is_cuda, hess_is_cuda) << "gradient and hessian should be on the same device.";
+
+  auto *learner = static_cast<Learner *>(handle);
+  auto ctx = learner->Ctx();
+  if (!grad_is_cuda) {
+    gpair.Reshape(i_grad.Shape(0), i_grad.Shape(1));
+    auto const shape = gpair.Shape();
+    auto h_gpair = gpair.HostView();
+    DispatchDType(i_grad, DeviceOrd::CPU(), [&](auto &&t_grad) {
+      DispatchDType(i_hess, DeviceOrd::CPU(), [&](auto &&t_hess) {
+        common::ParallelFor(h_gpair.Size(), ctx->Threads(),
+                            detail::CustomGradHessOp{t_grad, t_hess, h_gpair});
+      });
+    });
+  } else {
+    CopyGradientFromCUDAArrays(ctx, i_grad, i_hess, &gpair);
+  }
+  learner->BoostOneIter(iter, p_fmat, &gpair);
   API_END();
 }

View File

@@ -1,8 +1,12 @@
 /**
  * Copyright 2019-2023 by XGBoost Contributors
  */
-#include "../common/api_entry.h"  // XGBAPIThreadLocalEntry
+#include <thrust/transform.h>  // for transform
+
+#include "../common/api_entry.h"       // for XGBAPIThreadLocalEntry
+#include "../common/cuda_context.cuh"  // for CUDAContext
 #include "../common/threading_utils.h"
+#include "../data/array_interface.h"  // for DispatchDType, ArrayInterface
 #include "../data/device_adapter.cuh"
 #include "../data/proxy_dmatrix.h"
 #include "c_api_error.h"
@@ -13,7 +17,6 @@
 #include "xgboost/learner.h"

 namespace xgboost {
-
 void XGBBuildInfoDevice(Json *p_info) {
   auto &info = *p_info;
@@ -55,6 +58,27 @@ void XGBoostAPIGuard::RestoreGPUAttribute() {
   // If errors, do nothing, assuming running on CPU only machine.
   cudaSetDevice(device_id_);
 }
+
+void CopyGradientFromCUDAArrays(Context const *ctx, ArrayInterface<2, false> const &grad,
+                                ArrayInterface<2, false> const &hess,
+                                linalg::Matrix<GradientPair> *out_gpair) {
+  auto grad_dev = dh::CudaGetPointerDevice(grad.data);
+  auto hess_dev = dh::CudaGetPointerDevice(hess.data);
+  CHECK_EQ(grad_dev, hess_dev) << "gradient and hessian should be on the same device.";
+  auto &gpair = *out_gpair;
+  gpair.SetDevice(grad_dev);
+  gpair.Reshape(grad.Shape(0), grad.Shape(1));
+  auto d_gpair = gpair.View(grad_dev);
+  auto cuctx = ctx->CUDACtx();
+
+  DispatchDType(grad, DeviceOrd::CUDA(grad_dev), [&](auto &&t_grad) {
+    DispatchDType(hess, DeviceOrd::CUDA(hess_dev), [&](auto &&t_hess) {
+      CHECK_EQ(t_grad.Size(), t_hess.Size());
+      thrust::for_each_n(cuctx->CTP(), thrust::make_counting_iterator(0ul), t_grad.Size(),
+                         detail::CustomGradHessOp{t_grad, t_hess, d_gpair});
+    });
+  });
+}
 }  // namespace xgboost

 using namespace xgboost;  // NOLINT

View File

@@ -1,5 +1,5 @@
-/*!
- * Copyright (c) 2015-2022 by Contributors
+/**
+ * Copyright 2015-2023, XGBoost Contributors
  * \file c_api_error.h
  * \brief Error handling for C API.
  */
@@ -35,8 +35,8 @@
   }                          \
   return 0;  // NOLINT(*)

-#define CHECK_HANDLE() if (handle == nullptr) \
-  LOG(FATAL) << "DMatrix/Booster has not been initialized or has already been disposed.";
+#define CHECK_HANDLE() \
+  if (handle == nullptr) ::xgboost::detail::EmptyHandle();

 /*!
  * \brief Set the last error message needed by C API

View File

@ -7,8 +7,10 @@
#include <algorithm> #include <algorithm>
#include <cstddef> #include <cstddef>
#include <functional> #include <functional>
#include <memory> // std::shared_ptr #include <memory> // for shared_ptr
#include <string> #include <string> // for string
#include <tuple> // for make_tuple
#include <utility> // for move
#include <vector> #include <vector>
#include "xgboost/c_api.h" #include "xgboost/c_api.h"
@ -16,7 +18,7 @@
#include "xgboost/feature_map.h" // for FeatureMap #include "xgboost/feature_map.h" // for FeatureMap
#include "xgboost/json.h" #include "xgboost/json.h"
#include "xgboost/learner.h" #include "xgboost/learner.h"
#include "xgboost/linalg.h" // ArrayInterfaceHandler #include "xgboost/linalg.h" // ArrayInterfaceHandler, MakeTensorView, ArrayInterfaceStr
#include "xgboost/logging.h" #include "xgboost/logging.h"
#include "xgboost/string_view.h" // StringView #include "xgboost/string_view.h" // StringView
@ -287,6 +289,19 @@ inline std::shared_ptr<DMatrix> CastDMatrixHandle(DMatrixHandle const handle) {
} }
namespace detail { namespace detail {
inline void EmptyHandle() {
LOG(FATAL) << "DMatrix/Booster has not been initialized or has already been disposed.";
}
inline xgboost::Context const *BoosterCtx(BoosterHandle handle) {
if (handle == nullptr) {
EmptyHandle();
}
auto *learner = static_cast<xgboost::Learner *>(handle);
CHECK(learner);
return learner->Ctx();
}
template <typename PtrT, typename I, typename T> template <typename PtrT, typename I, typename T>
void MakeSparseFromPtr(PtrT const *p_indptr, I const *p_indices, T const *p_data, void MakeSparseFromPtr(PtrT const *p_indptr, I const *p_indices, T const *p_data,
std::size_t nindptr, std::string *indptr_str, std::string *indices_str, std::size_t nindptr, std::string *indptr_str, std::string *indices_str,
@ -334,6 +349,40 @@ void MakeSparseFromPtr(PtrT const *p_indptr, I const *p_indices, T const *p_data
Json::Dump(jindices, indices_str); Json::Dump(jindices, indices_str);
Json::Dump(jdata, data_str); Json::Dump(jdata, data_str);
} }
/**
* @brief Make array interface for other language bindings.
*/
template <typename G, typename H>
auto MakeGradientInterface(Context const *ctx, G const *grad, H const *hess, std::size_t n_samples,
std::size_t n_targets) {
auto t_grad =
linalg::MakeTensorView(ctx, common::Span{grad, n_samples * n_targets}, n_samples, n_targets);
auto t_hess =
linalg::MakeTensorView(ctx, common::Span{hess, n_samples * n_targets}, n_samples, n_targets);
auto s_grad = linalg::ArrayInterfaceStr(t_grad);
auto s_hess = linalg::ArrayInterfaceStr(t_hess);
return std::make_tuple(s_grad, s_hess);
}
template <typename G, typename H>
struct CustomGradHessOp {
linalg::MatrixView<G> t_grad;
linalg::MatrixView<H> t_hess;
linalg::MatrixView<GradientPair> d_gpair;
CustomGradHessOp(linalg::MatrixView<G> t_grad, linalg::MatrixView<H> t_hess,
linalg::MatrixView<GradientPair> d_gpair)
: t_grad{std::move(t_grad)}, t_hess{std::move(t_hess)}, d_gpair{std::move(d_gpair)} {}
XGBOOST_DEVICE void operator()(std::size_t i) {
auto [m, n] = linalg::UnravelIndex(i, t_grad.Shape(0), t_grad.Shape(1));
auto g = t_grad(m, n);
auto h = t_hess(m, n);
// from struct of arrays to array of structs.
d_gpair(m, n) = GradientPair{static_cast<float>(g), static_cast<float>(h)};
}
};
} // namespace detail } // namespace detail
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_C_API_C_API_UTILS_H_ #endif // XGBOOST_C_API_C_API_UTILS_H_
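The two helpers added above (MakeGradientInterface and CustomGradHessOp) describe the custom-objective gradients to the bindings as a (n_samples, n_targets) array interface and then pack the separate grad/hess buffers into GradientPair entries. Below is a minimal standalone sketch of that struct-of-arrays to array-of-structs packing; it uses plain std::vector and a hypothetical GradPair stand-in instead of the XGBoost headers, so it only illustrates the layout, not the actual implementation.

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Hypothetical stand-in for xgboost::GradientPair: one (grad, hess) entry per cell.
    struct GradPair {
      float grad;
      float hess;
    };

    int main() {
      // A custom objective hands back two row-major buffers of shape (n_samples, n_targets);
      // the booster stores them interleaved as pairs in the same row-major order.
      std::size_t const n_samples = 3, n_targets = 2;
      std::vector<float> grad = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f};
      std::vector<float> hess = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};

      std::vector<GradPair> gpair(n_samples * n_targets);
      for (std::size_t i = 0; i < n_samples; ++i) {
        for (std::size_t j = 0; j < n_targets; ++j) {
          std::size_t idx = i * n_targets + j;  // row-major flattening of (i, j)
          gpair[idx] = GradPair{grad[idx], hess[idx]};
        }
      }
      std::cout << gpair[1 * n_targets + 0].grad << "\n";  // sample 1, target 0 -> 0.3
      return 0;
    }

With this layout, cell (i, j) of both buffers lands at offset i * n_targets + j, which is the row-major convention the new matrix-based API expects from custom objectives.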

View File

@ -384,7 +384,7 @@ inline bool ArrayInterfaceHandler::IsCudaPtr(void const *) { return false; }
* numpy has the proper support even though it's in the __cuda_array_interface__ * numpy has the proper support even though it's in the __cuda_array_interface__
* protocol defined by numba. * protocol defined by numba.
*/ */
template <int32_t D, bool allow_mask = (D == 1)> template <std::int32_t D, bool allow_mask = (D == 1)>
class ArrayInterface { class ArrayInterface {
static_assert(D > 0, "Invalid dimension for array interface."); static_assert(D > 0, "Invalid dimension for array interface.");
@ -588,7 +588,7 @@ class ArrayInterface {
}; };
template <std::int32_t D, typename Fn> template <std::int32_t D, typename Fn>
void DispatchDType(ArrayInterface<D> const array, std::int32_t device, Fn fn) { void DispatchDType(ArrayInterface<D> const array, DeviceOrd device, Fn fn) {
// Only used for cuDF at the moment. // Only used for cuDF at the moment.
CHECK_EQ(array.valid.Capacity(), 0); CHECK_EQ(array.valid.Capacity(), 0);
auto dispatch = [&](auto t) { auto dispatch = [&](auto t) {

View File

@ -448,7 +448,7 @@ void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T
auto t_out = p_out->View(Context::kCpuId); auto t_out = p_out->View(Context::kCpuId);
CHECK(t_out.CContiguous()); CHECK(t_out.CContiguous());
auto const shape = t_out.Shape(); auto const shape = t_out.Shape();
DispatchDType(array, Context::kCpuId, [&](auto&& in) { DispatchDType(array, DeviceOrd::CPU(), [&](auto&& in) {
linalg::ElementWiseTransformHost(t_out, ctx.Threads(), [&](auto i, auto) { linalg::ElementWiseTransformHost(t_out, ctx.Threads(), [&](auto i, auto) {
return std::apply(in, linalg::UnravelIndex<D>(i, shape)); return std::apply(in, linalg::UnravelIndex<D>(i, shape));
}); });

View File

@ -29,7 +29,6 @@
#include "../common/error_msg.h" #include "../common/error_msg.h"
namespace xgboost::gbm { namespace xgboost::gbm {
DMLC_REGISTRY_FILE_TAG(gblinear); DMLC_REGISTRY_FILE_TAG(gblinear);
// training parameters // training parameters
@ -142,7 +141,7 @@ class GBLinear : public GradientBooster {
this->updater_->SaveConfig(&j_updater); this->updater_->SaveConfig(&j_updater);
} }
void DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair, PredictionCacheEntry*, void DoBoost(DMatrix* p_fmat, linalg::Matrix<GradientPair>* in_gpair, PredictionCacheEntry*,
ObjFunction const*) override { ObjFunction const*) override {
monitor_.Start("DoBoost"); monitor_.Start("DoBoost");
@ -232,8 +231,7 @@ class GBLinear : public GradientBooster {
std::fill(contribs.begin(), contribs.end(), 0); std::fill(contribs.begin(), contribs.end(), 0);
} }
std::vector<std::string> DumpModel(const FeatureMap& fmap, [[nodiscard]] std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
bool with_stats,
std::string format) const override { std::string format) const override {
return model_.DumpModel(fmap, with_stats, format); return model_.DumpModel(fmap, with_stats, format);
} }
@ -263,7 +261,7 @@ class GBLinear : public GradientBooster {
} }
} }
bool UseGPU() const override { [[nodiscard]] bool UseGPU() const override {
if (param_.updater == "gpu_coord_descent") { if (param_.updater == "gpu_coord_descent") {
return true; return true;
} else { } else {

View File

@ -167,8 +167,8 @@ void GBTree::Configure(Args const& cfg) {
} }
} }
void GPUCopyGradient(HostDeviceVector<GradientPair> const*, bst_group_t, bst_group_t, void GPUCopyGradient(Context const*, linalg::Matrix<GradientPair> const*, bst_group_t,
HostDeviceVector<GradientPair>*) linalg::Matrix<GradientPair>*)
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
; // NOLINT ; // NOLINT
#else #else
@ -177,16 +177,19 @@ void GPUCopyGradient(HostDeviceVector<GradientPair> const*, bst_group_t, bst_gro
} }
#endif #endif
void CopyGradient(HostDeviceVector<GradientPair> const* in_gpair, int32_t n_threads, void CopyGradient(Context const* ctx, linalg::Matrix<GradientPair> const* in_gpair,
bst_group_t n_groups, bst_group_t group_id, bst_group_t group_id, linalg::Matrix<GradientPair>* out_gpair) {
HostDeviceVector<GradientPair>* out_gpair) { out_gpair->SetDevice(ctx->Device());
if (in_gpair->DeviceIdx() != Context::kCpuId) { out_gpair->Reshape(in_gpair->Shape(0), 1);
GPUCopyGradient(in_gpair, n_groups, group_id, out_gpair); if (ctx->IsCUDA()) {
GPUCopyGradient(ctx, in_gpair, group_id, out_gpair);
} else { } else {
std::vector<GradientPair> &tmp_h = out_gpair->HostVector(); auto const& in = *in_gpair;
const auto& gpair_h = in_gpair->ConstHostVector(); auto target_gpair = in.Slice(linalg::All(), group_id);
common::ParallelFor(out_gpair->Size(), n_threads, auto h_tmp = out_gpair->HostView();
[&](auto i) { tmp_h[i] = gpair_h[i * n_groups + group_id]; }); auto h_in = in.HostView().Slice(linalg::All(), group_id);
CHECK_EQ(h_tmp.Size(), h_in.Size());
common::ParallelFor(h_in.Size(), ctx->Threads(), [&](auto i) { h_tmp(i) = h_in(i); });
} }
} }
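CopyGradient now slices one output column out of the (n_samples, n_groups) gradient matrix instead of gathering with a manual stride into a flat vector. A small self-contained sketch (plain std::vector and a lambda rather than the linalg types) showing that the column slice and the old i * n_groups + group_id gather pick out the same values:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main() {
      // Gradients for 4 samples and 3 output groups, stored row-major as (sample, group).
      std::size_t const n_samples = 4, n_groups = 3, group_id = 1;
      std::vector<float> gpair(n_samples * n_groups);
      for (std::size_t i = 0; i < gpair.size(); ++i) gpair[i] = static_cast<float>(i);

      // Old style: strided gather into a flat temporary.
      std::vector<float> by_stride(n_samples);
      for (std::size_t i = 0; i < n_samples; ++i) by_stride[i] = gpair[i * n_groups + group_id];

      // New style, conceptually: view the buffer as a matrix and take column group_id.
      auto at = [&](std::size_t r, std::size_t c) { return gpair[r * n_groups + c]; };
      std::vector<float> by_slice(n_samples);
      for (std::size_t i = 0; i < n_samples; ++i) by_slice[i] = at(i, group_id);

      std::cout << std::boolalpha << (by_stride == by_slice) << "\n";  // true
      return 0;
    }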
@ -215,7 +218,7 @@ void GBTree::UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector<float> const
} }
} }
void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair, void GBTree::DoBoost(DMatrix* p_fmat, linalg::Matrix<GradientPair>* in_gpair,
PredictionCacheEntry* predt, ObjFunction const* obj) { PredictionCacheEntry* predt, ObjFunction const* obj) {
if (model_.learner_model_param->IsVectorLeaf()) { if (model_.learner_model_param->IsVectorLeaf()) {
CHECK(tparam_.tree_method == TreeMethod::kHist || tparam_.tree_method == TreeMethod::kAuto) CHECK(tparam_.tree_method == TreeMethod::kHist || tparam_.tree_method == TreeMethod::kAuto)
@ -263,12 +266,12 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
} }
} else { } else {
CHECK_EQ(in_gpair->Size() % n_groups, 0U) << "must have exactly ngroup * nrow gpairs"; CHECK_EQ(in_gpair->Size() % n_groups, 0U) << "must have exactly ngroup * nrow gpairs";
HostDeviceVector<GradientPair> tmp(in_gpair->Size() / n_groups, GradientPair(), linalg::Matrix<GradientPair> tmp{{in_gpair->Shape(0), static_cast<std::size_t>(1ul)},
in_gpair->DeviceIdx()); ctx_->Ordinal()};
bool update_predict = true; bool update_predict = true;
for (bst_target_t gid = 0; gid < n_groups; ++gid) { for (bst_target_t gid = 0; gid < n_groups; ++gid) {
node_position.clear(); node_position.clear();
CopyGradient(in_gpair, ctx_->Threads(), n_groups, gid, &tmp); CopyGradient(ctx_, in_gpair, gid, &tmp);
TreesOneGroup ret; TreesOneGroup ret;
BoostNewTrees(&tmp, p_fmat, gid, &node_position, &ret); BoostNewTrees(&tmp, p_fmat, gid, &node_position, &ret);
UpdateTreeLeaf(p_fmat, predt->predictions, obj, gid, node_position, &ret); UpdateTreeLeaf(p_fmat, predt->predictions, obj, gid, node_position, &ret);
@ -289,7 +292,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
this->CommitModel(std::move(new_trees)); this->CommitModel(std::move(new_trees));
} }
void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, int bst_group, void GBTree::BoostNewTrees(linalg::Matrix<GradientPair>* gpair, DMatrix* p_fmat, int bst_group,
std::vector<HostDeviceVector<bst_node_t>>* out_position, std::vector<HostDeviceVector<bst_node_t>>* out_position,
TreesOneGroup* ret) { TreesOneGroup* ret) {
std::vector<RegTree*> new_trees; std::vector<RegTree*> new_trees;

View File

@ -1,22 +1,24 @@
/** /**
* Copyright 2021-2023, XGBoost Contributors * Copyright 2021-2023, XGBoost Contributors
*/ */
#include "../common/device_helpers.cuh" #include <thrust/iterator/counting_iterator.h> // for make_counting_iterator
#include "xgboost/linalg.h"
#include "xgboost/span.h" #include "../common/cuda_context.cuh"
#include "../common/device_helpers.cuh" // for MakeTransformIterator
#include "xgboost/base.h" // for GradientPair
#include "xgboost/linalg.h" // for Matrix
namespace xgboost::gbm { namespace xgboost::gbm {
void GPUCopyGradient(HostDeviceVector<GradientPair> const *in_gpair, void GPUCopyGradient(Context const *ctx, linalg::Matrix<GradientPair> const *in_gpair,
bst_group_t n_groups, bst_group_t group_id, bst_group_t group_id, linalg::Matrix<GradientPair> *out_gpair) {
HostDeviceVector<GradientPair> *out_gpair) { auto v_in = in_gpair->View(ctx->Device()).Slice(linalg::All(), group_id);
auto mat = linalg::TensorView<GradientPair const, 2>( out_gpair->SetDevice(ctx->Device());
in_gpair->ConstDeviceSpan(), out_gpair->Reshape(v_in.Size(), 1);
{in_gpair->Size() / n_groups, static_cast<size_t>(n_groups)}, auto d_out = out_gpair->View(ctx->Device());
in_gpair->DeviceIdx()); auto cuctx = ctx->CUDACtx();
auto v_in = mat.Slice(linalg::All(), group_id); auto it = dh::MakeTransformIterator<GradientPair>(
out_gpair->Resize(v_in.Size()); thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) { return v_in(i); });
auto d_out = out_gpair->DeviceSpan(); thrust::copy(cuctx->CTP(), it, it + v_in.Size(), d_out.Values().data());
dh::LaunchN(v_in.Size(), [=] __device__(size_t i) { d_out[i] = v_in(i); });
} }
void GPUDartPredictInc(common::Span<float> out_predts, void GPUDartPredictInc(common::Span<float> out_predts,

View File

@ -183,8 +183,8 @@ class GBTree : public GradientBooster {
/** /**
* @brief Carry out one iteration of boosting. * @brief Carry out one iteration of boosting.
*/ */
void DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair, void DoBoost(DMatrix* p_fmat, linalg::Matrix<GradientPair>* in_gpair, PredictionCacheEntry* predt,
PredictionCacheEntry* predt, ObjFunction const* obj) override; ObjFunction const* obj) override;
[[nodiscard]] bool UseGPU() const override { return tparam_.tree_method == TreeMethod::kGPUHist; } [[nodiscard]] bool UseGPU() const override { return tparam_.tree_method == TreeMethod::kGPUHist; }
@ -326,7 +326,7 @@ class GBTree : public GradientBooster {
} }
protected: protected:
void BoostNewTrees(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, int bst_group, void BoostNewTrees(linalg::Matrix<GradientPair>* gpair, DMatrix* p_fmat, int bst_group,
std::vector<HostDeviceVector<bst_node_t>>* out_position, std::vector<HostDeviceVector<bst_node_t>>* out_position,
std::vector<std::unique_ptr<RegTree>>* ret); std::vector<std::unique_ptr<RegTree>>* ret);

View File

@ -1282,14 +1282,14 @@ class LearnerImpl : public LearnerIO {
monitor_.Start("GetGradient"); monitor_.Start("GetGradient");
GetGradient(predt.predictions, train->Info(), iter, &gpair_); GetGradient(predt.predictions, train->Info(), iter, &gpair_);
monitor_.Stop("GetGradient"); monitor_.Stop("GetGradient");
TrainingObserver::Instance().Observe(gpair_, "Gradients"); TrainingObserver::Instance().Observe(*gpair_.Data(), "Gradients");
gbm_->DoBoost(train.get(), &gpair_, &predt, obj_.get()); gbm_->DoBoost(train.get(), &gpair_, &predt, obj_.get());
monitor_.Stop("UpdateOneIter"); monitor_.Stop("UpdateOneIter");
} }
void BoostOneIter(int iter, std::shared_ptr<DMatrix> train, void BoostOneIter(int iter, std::shared_ptr<DMatrix> train,
HostDeviceVector<GradientPair>* in_gpair) override { linalg::Matrix<GradientPair>* in_gpair) override {
monitor_.Start("BoostOneIter"); monitor_.Start("BoostOneIter");
this->Configure(); this->Configure();
@ -1299,6 +1299,9 @@ class LearnerImpl : public LearnerIO {
this->ValidateDMatrix(train.get(), true); this->ValidateDMatrix(train.get(), true);
CHECK_EQ(this->learner_model_param_.OutputLength(), in_gpair->Shape(1))
<< "The number of columns in gradient should be equal to the number of targets/classes in "
"the model.";
auto& predt = prediction_container_.Cache(train, ctx_.gpu_id); auto& predt = prediction_container_.Cache(train, ctx_.gpu_id);
gbm_->DoBoost(train.get(), in_gpair, &predt, obj_.get()); gbm_->DoBoost(train.get(), in_gpair, &predt, obj_.get());
monitor_.Stop("BoostOneIter"); monitor_.Stop("BoostOneIter");
@ -1461,18 +1464,18 @@ class LearnerImpl : public LearnerIO {
} }
private: private:
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration, void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info,
HostDeviceVector<GradientPair>* out_gpair) { std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) {
out_gpair->Resize(preds.Size()); out_gpair->Reshape(info.num_row_, this->learner_model_param_.OutputLength());
collective::ApplyWithLabels(info, out_gpair->HostPointer(), collective::ApplyWithLabels(info, out_gpair->Data()->HostPointer(),
out_gpair->Size() * sizeof(GradientPair), out_gpair->Size() * sizeof(GradientPair),
[&] { obj_->GetGradient(preds, info, iteration, out_gpair); }); [&] { obj_->GetGradient(preds, info, iter, out_gpair); });
} }
/*! \brief random number transformation seed. */ /*! \brief random number transformation seed. */
static int32_t constexpr kRandSeedMagic = 127; static int32_t constexpr kRandSeedMagic = 127;
// gradient pairs // gradient pairs
HostDeviceVector<GradientPair> gpair_; linalg::Matrix<GradientPair> gpair_;
/*! \brief Temporary storage to prediction. Useful for storing data transformed by /*! \brief Temporary storage to prediction. Useful for storing data transformed by
* objective function */ * objective function */
PredictionContainer output_predictions_; PredictionContainer output_predictions_;
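BoostOneIter now checks that the gradient handed in by a custom objective has one column per model output, matching the commit message's requirement that multi-class/multi-target objectives return the correct shape. A hypothetical standalone helper (ValidateGradientShape is not part of this commit) sketching the same shape contract:

    #include <cstddef>
    #include <stdexcept>

    // Mirrors the new contract: the gradient matrix handed to the booster must have
    // one row per sample and one column per model output (class or target).
    void ValidateGradientShape(std::size_t n_rows, std::size_t n_cols,
                               std::size_t num_row, std::size_t output_length) {
      if (n_rows != num_row || n_cols != output_length) {
        throw std::invalid_argument(
            "Custom objective must return gradients of shape (n_samples, n_targets).");
      }
    }

    int main() {
      ValidateGradientShape(100, 3, 100, 3);   // 100 rows, 3 classes: accepted
      // ValidateGradientShape(300, 1, 100, 3); // a flattened (n * k, 1) gradient is rejected
      return 0;
    }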

View File

@ -45,30 +45,31 @@ class CoordinateUpdater : public LinearUpdater {
out["coordinate_param"] = ToJson(cparam_); out["coordinate_param"] = ToJson(cparam_);
} }
void Update(HostDeviceVector<GradientPair> *in_gpair, DMatrix *p_fmat, void Update(linalg::Matrix<GradientPair> *in_gpair, DMatrix *p_fmat, gbm::GBLinearModel *model,
gbm::GBLinearModel *model, double sum_instance_weight) override { double sum_instance_weight) override {
auto gpair = in_gpair->Data();
tparam_.DenormalizePenalties(sum_instance_weight); tparam_.DenormalizePenalties(sum_instance_weight);
const int ngroup = model->learner_model_param->num_output_group; const int ngroup = model->learner_model_param->num_output_group;
// update bias // update bias
for (int group_idx = 0; group_idx < ngroup; ++group_idx) { for (int group_idx = 0; group_idx < ngroup; ++group_idx) {
auto grad = GetBiasGradientParallel(group_idx, ngroup, in_gpair->ConstHostVector(), p_fmat, auto grad = GetBiasGradientParallel(group_idx, ngroup, gpair->ConstHostVector(), p_fmat,
ctx_->Threads()); ctx_->Threads());
auto dbias = static_cast<float>(tparam_.learning_rate * auto dbias = static_cast<float>(tparam_.learning_rate *
CoordinateDeltaBias(grad.first, grad.second)); CoordinateDeltaBias(grad.first, grad.second));
model->Bias()[group_idx] += dbias; model->Bias()[group_idx] += dbias;
UpdateBiasResidualParallel(ctx_, group_idx, ngroup, dbias, &in_gpair->HostVector(), p_fmat); UpdateBiasResidualParallel(ctx_, group_idx, ngroup, dbias, &gpair->HostVector(), p_fmat);
} }
// prepare for updating the weights // prepare for updating the weights
selector_->Setup(ctx_, *model, in_gpair->ConstHostVector(), p_fmat, tparam_.reg_alpha_denorm, selector_->Setup(ctx_, *model, gpair->ConstHostVector(), p_fmat, tparam_.reg_alpha_denorm,
tparam_.reg_lambda_denorm, cparam_.top_k); tparam_.reg_lambda_denorm, cparam_.top_k);
// update weights // update weights
for (int group_idx = 0; group_idx < ngroup; ++group_idx) { for (int group_idx = 0; group_idx < ngroup; ++group_idx) {
for (unsigned i = 0U; i < model->learner_model_param->num_feature; i++) { for (unsigned i = 0U; i < model->learner_model_param->num_feature; i++) {
int fidx = int fidx =
selector_->NextFeature(ctx_, i, *model, group_idx, in_gpair->ConstHostVector(), p_fmat, selector_->NextFeature(ctx_, i, *model, group_idx, gpair->ConstHostVector(), p_fmat,
tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm); tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm);
if (fidx < 0) break; if (fidx < 0) break;
this->UpdateFeature(fidx, group_idx, &in_gpair->HostVector(), p_fmat, model); this->UpdateFeature(fidx, group_idx, &gpair->HostVector(), p_fmat, model);
} }
} }
monitor_.Stop("UpdateFeature"); monitor_.Stop("UpdateFeature");

View File

@ -93,17 +93,18 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
} }
} }
void Update(HostDeviceVector<GradientPair> *in_gpair, DMatrix *p_fmat, void Update(linalg::Matrix<GradientPair> *in_gpair, DMatrix *p_fmat, gbm::GBLinearModel *model,
gbm::GBLinearModel *model, double sum_instance_weight) override { double sum_instance_weight) override {
tparam_.DenormalizePenalties(sum_instance_weight); tparam_.DenormalizePenalties(sum_instance_weight);
monitor_.Start("LazyInitDevice"); monitor_.Start("LazyInitDevice");
this->LazyInitDevice(p_fmat, *(model->learner_model_param)); this->LazyInitDevice(p_fmat, *(model->learner_model_param));
monitor_.Stop("LazyInitDevice"); monitor_.Stop("LazyInitDevice");
monitor_.Start("UpdateGpair"); monitor_.Start("UpdateGpair");
// Update gpair // Update gpair
if (ctx_->gpu_id >= 0) { if (ctx_->IsCUDA()) {
this->UpdateGpair(in_gpair->ConstHostVector()); this->UpdateGpair(in_gpair->Data()->ConstHostVector());
} }
monitor_.Stop("UpdateGpair"); monitor_.Stop("UpdateGpair");
@ -111,15 +112,15 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
this->UpdateBias(model); this->UpdateBias(model);
monitor_.Stop("UpdateBias"); monitor_.Stop("UpdateBias");
// prepare for updating the weights // prepare for updating the weights
selector_->Setup(ctx_, *model, in_gpair->ConstHostVector(), p_fmat, tparam_.reg_alpha_denorm, selector_->Setup(ctx_, *model, in_gpair->Data()->ConstHostVector(), p_fmat,
tparam_.reg_lambda_denorm, coord_param_.top_k); tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm, coord_param_.top_k);
monitor_.Start("UpdateFeature"); monitor_.Start("UpdateFeature");
for (uint32_t group_idx = 0; group_idx < model->learner_model_param->num_output_group; for (uint32_t group_idx = 0; group_idx < model->learner_model_param->num_output_group;
++group_idx) { ++group_idx) {
for (auto i = 0U; i < model->learner_model_param->num_feature; i++) { for (auto i = 0U; i < model->learner_model_param->num_feature; i++) {
auto fidx = auto fidx =
selector_->NextFeature(ctx_, i, *model, group_idx, in_gpair->ConstHostVector(), p_fmat, selector_->NextFeature(ctx_, i, *model, group_idx, in_gpair->Data()->ConstHostVector(),
tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm); p_fmat, tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm);
if (fidx < 0) break; if (fidx < 0) break;
this->UpdateFeature(fidx, group_idx, model); this->UpdateFeature(fidx, group_idx, model);
} }

View File

@ -6,8 +6,7 @@
#include <xgboost/linear_updater.h> #include <xgboost/linear_updater.h>
#include "coordinate_common.h" #include "coordinate_common.h"
namespace xgboost { namespace xgboost::linear {
namespace linear {
DMLC_REGISTRY_FILE_TAG(updater_shotgun); DMLC_REGISTRY_FILE_TAG(updater_shotgun);
@ -32,30 +31,31 @@ class ShotgunUpdater : public LinearUpdater {
out["linear_train_param"] = ToJson(param_); out["linear_train_param"] = ToJson(param_);
} }
void Update(HostDeviceVector<GradientPair> *in_gpair, DMatrix *p_fmat, void Update(linalg::Matrix<GradientPair> *in_gpair, DMatrix *p_fmat, gbm::GBLinearModel *model,
gbm::GBLinearModel *model, double sum_instance_weight) override { double sum_instance_weight) override {
auto &gpair = in_gpair->HostVector(); auto gpair = in_gpair->Data();
param_.DenormalizePenalties(sum_instance_weight); param_.DenormalizePenalties(sum_instance_weight);
const int ngroup = model->learner_model_param->num_output_group; const int ngroup = model->learner_model_param->num_output_group;
// update bias // update bias
for (int gid = 0; gid < ngroup; ++gid) { for (int gid = 0; gid < ngroup; ++gid) {
auto grad = GetBiasGradientParallel(gid, ngroup, in_gpair->ConstHostVector(), p_fmat, auto grad = GetBiasGradientParallel(gid, ngroup, gpair->ConstHostVector(), p_fmat,
ctx_->Threads()); ctx_->Threads());
auto dbias = static_cast<bst_float>(param_.learning_rate * auto dbias = static_cast<bst_float>(param_.learning_rate *
CoordinateDeltaBias(grad.first, grad.second)); CoordinateDeltaBias(grad.first, grad.second));
model->Bias()[gid] += dbias; model->Bias()[gid] += dbias;
UpdateBiasResidualParallel(ctx_, gid, ngroup, dbias, &in_gpair->HostVector(), p_fmat); UpdateBiasResidualParallel(ctx_, gid, ngroup, dbias, &gpair->HostVector(), p_fmat);
} }
// lock-free parallel updates of weights // lock-free parallel updates of weights
selector_->Setup(ctx_, *model, in_gpair->ConstHostVector(), p_fmat, param_.reg_alpha_denorm, selector_->Setup(ctx_, *model, gpair->ConstHostVector(), p_fmat, param_.reg_alpha_denorm,
param_.reg_lambda_denorm, 0); param_.reg_lambda_denorm, 0);
auto &h_gpair = gpair->HostVector();
for (const auto &batch : p_fmat->GetBatches<CSCPage>(ctx_)) { for (const auto &batch : p_fmat->GetBatches<CSCPage>(ctx_)) {
auto page = batch.GetView(); auto page = batch.GetView();
const auto nfeat = static_cast<bst_omp_uint>(batch.Size()); const auto nfeat = static_cast<bst_omp_uint>(batch.Size());
common::ParallelFor(nfeat, ctx_->Threads(), [&](auto i) { common::ParallelFor(nfeat, ctx_->Threads(), [&](auto i) {
int ii = selector_->NextFeature(ctx_, i, *model, 0, in_gpair->ConstHostVector(), p_fmat, int ii = selector_->NextFeature(ctx_, i, *model, 0, gpair->ConstHostVector(), p_fmat,
param_.reg_alpha_denorm, param_.reg_lambda_denorm); param_.reg_alpha_denorm, param_.reg_lambda_denorm);
if (ii < 0) return; if (ii < 0) return;
const bst_uint fid = ii; const bst_uint fid = ii;
@ -63,7 +63,7 @@ class ShotgunUpdater : public LinearUpdater {
for (int gid = 0; gid < ngroup; ++gid) { for (int gid = 0; gid < ngroup; ++gid) {
double sum_grad = 0.0, sum_hess = 0.0; double sum_grad = 0.0, sum_hess = 0.0;
for (auto &c : col) { for (auto &c : col) {
const GradientPair &p = gpair[c.index * ngroup + gid]; const GradientPair &p = h_gpair[c.index * ngroup + gid];
if (p.GetHess() < 0.0f) continue; if (p.GetHess() < 0.0f) continue;
const bst_float v = c.fvalue; const bst_float v = c.fvalue;
sum_grad += p.GetGrad() * v; sum_grad += p.GetGrad() * v;
@ -77,7 +77,7 @@ class ShotgunUpdater : public LinearUpdater {
w += dw; w += dw;
// update grad values // update grad values
for (auto &c : col) { for (auto &c : col) {
GradientPair &p = gpair[c.index * ngroup + gid]; GradientPair &p = h_gpair[c.index * ngroup + gid];
if (p.GetHess() < 0.0f) continue; if (p.GetHess() < 0.0f) continue;
p += GradientPair(p.GetHess() * c.fvalue * dw, 0); p += GradientPair(p.GetHess() * c.fvalue * dw, 0);
} }
@ -98,5 +98,4 @@ XGBOOST_REGISTER_LINEAR_UPDATER(ShotgunUpdater, "shotgun")
"Update linear model according to shotgun coordinate descent " "Update linear model according to shotgun coordinate descent "
"algorithm.") "algorithm.")
.set_body([]() { return new ShotgunUpdater(); }); .set_body([]() { return new ShotgunUpdater(); });
} // namespace linear } // namespace xgboost::linear
} // namespace xgboost
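The linear updaters above keep working on the flat host vector obtained via in_gpair->Data()->HostVector(), while the rest of the code indexes the same storage as a (sample, group) matrix. A toy sketch (ToyMatrixView is an illustrative stand-in, not the real linalg::MatrixView) of how a matrix view and flat indexing share one backing store:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Illustrative stand-in (not the real linalg::MatrixView): a row-major view over a
    // flat backing store, so matrix-style and flat indexing touch the same memory.
    struct ToyMatrixView {
      std::vector<float>* data;
      std::size_t n_cols;
      float& operator()(std::size_t r, std::size_t c) { return (*data)[r * n_cols + c]; }
    };

    int main() {
      std::size_t const n_rows = 2, n_cols = 3;
      std::vector<float> storage(n_rows * n_cols, 0.0f);  // what the flat host vector exposes
      ToyMatrixView view{&storage, n_cols};

      view(1, 2) = 42.0f;                            // write through the matrix view
      std::cout << storage[1 * n_cols + 2] << "\n";  // the flat read sees 42
      return 0;
    }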

View File

@ -1,5 +1,5 @@
/*! /**
* Copyright 2019-2022 by Contributors * Copyright 2019-2023, XGBoost Contributors
* \file aft_obj.cu * \file aft_obj.cu
* \brief Definition of AFT loss for survival analysis. * \brief Definition of AFT loss for survival analysis.
* \author Avinash Barnwal, Hyunsu Cho and Toby Hocking * \author Avinash Barnwal, Hyunsu Cho and Toby Hocking
@ -41,11 +41,9 @@ class AFTObj : public ObjFunction {
ObjInfo Task() const override { return ObjInfo::kSurvival; } ObjInfo Task() const override { return ObjInfo::kSurvival; }
template <typename Distribution> template <typename Distribution>
void GetGradientImpl(const HostDeviceVector<bst_float> &preds, void GetGradientImpl(const HostDeviceVector<bst_float>& preds, const MetaInfo& info,
const MetaInfo &info, linalg::Matrix<GradientPair>* out_gpair, size_t ndata, int device,
HostDeviceVector<GradientPair> *out_gpair, bool is_null_weight, float aft_loss_distribution_scale) {
size_t ndata, int device, bool is_null_weight,
float aft_loss_distribution_scale) {
common::Transform<>::Init( common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx, [=] XGBOOST_DEVICE(size_t _idx,
common::Span<GradientPair> _out_gpair, common::Span<GradientPair> _out_gpair,
@ -66,16 +64,17 @@ class AFTObj : public ObjFunction {
_out_gpair[_idx] = GradientPair(grad * w, hess * w); _out_gpair[_idx] = GradientPair(grad * w, hess * w);
}, },
common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(), device).Eval( common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(), device).Eval(
out_gpair, &preds, &info.labels_lower_bound_, &info.labels_upper_bound_, out_gpair->Data(), &preds, &info.labels_lower_bound_, &info.labels_upper_bound_,
&info.weights_); &info.weights_);
} }
void GetGradient(const HostDeviceVector<bst_float>& preds, const MetaInfo& info, int /*iter*/, void GetGradient(const HostDeviceVector<bst_float>& preds, const MetaInfo& info, int /*iter*/,
HostDeviceVector<GradientPair>* out_gpair) override { linalg::Matrix<GradientPair>* out_gpair) override {
const size_t ndata = preds.Size(); const size_t ndata = preds.Size();
CHECK_EQ(info.labels_lower_bound_.Size(), ndata); CHECK_EQ(info.labels_lower_bound_.Size(), ndata);
CHECK_EQ(info.labels_upper_bound_.Size(), ndata); CHECK_EQ(info.labels_upper_bound_.Size(), ndata);
out_gpair->Resize(ndata); out_gpair->SetDevice(ctx_->Device());
out_gpair->Reshape(ndata, 1);
const int device = ctx_->gpu_id; const int device = ctx_->gpu_id;
const float aft_loss_distribution_scale = param_.aft_loss_distribution_scale; const float aft_loss_distribution_scale = param_.aft_loss_distribution_scale;
const bool is_null_weight = info.weights_.Size() == 0; const bool is_null_weight = info.weights_.Size() == 0;

View File

@ -27,8 +27,8 @@ class HingeObj : public ObjFunction {
void Configure(Args const&) override {} void Configure(Args const&) override {}
ObjInfo Task() const override { return ObjInfo::kRegression; } ObjInfo Task() const override { return ObjInfo::kRegression; }
void GetGradient(const HostDeviceVector<bst_float> &preds, const MetaInfo &info, int /*iter*/, void GetGradient(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
HostDeviceVector<GradientPair> *out_gpair) override { std::int32_t /*iter*/, linalg::Matrix<GradientPair> *out_gpair) override {
CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty"; CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels.Size()) CHECK_EQ(preds.Size(), info.labels.Size())
<< "labels are not correctly provided" << "labels are not correctly provided"
@ -41,7 +41,8 @@ class HingeObj : public ObjFunction {
CHECK_EQ(info.weights_.Size(), ndata) CHECK_EQ(info.weights_.Size(), ndata)
<< "Number of weights should be equal to number of data points."; << "Number of weights should be equal to number of data points.";
} }
out_gpair->Resize(ndata); CHECK_EQ(info.labels.Shape(1), 1) << "Multi-target for `binary:hinge` is not yet supported.";
out_gpair->Reshape(ndata, 1);
common::Transform<>::Init( common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx, [=] XGBOOST_DEVICE(size_t _idx,
common::Span<GradientPair> _out_gpair, common::Span<GradientPair> _out_gpair,
@ -63,7 +64,7 @@ class HingeObj : public ObjFunction {
}, },
common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(), common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(),
ctx_->gpu_id).Eval( ctx_->gpu_id).Eval(
out_gpair, &preds, info.labels.Data(), &info.weights_); out_gpair->Data(), &preds, info.labels.Data(), &info.weights_);
} }
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override { void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {

View File

@ -21,7 +21,7 @@ void FitIntercept::InitEstimation(MetaInfo const& info, linalg::Vector<float>* b
} }
// Avoid altering any state in child objective. // Avoid altering any state in child objective.
HostDeviceVector<float> dummy_predt(info.labels.Size(), 0.0f, this->ctx_->gpu_id); HostDeviceVector<float> dummy_predt(info.labels.Size(), 0.0f, this->ctx_->gpu_id);
HostDeviceVector<GradientPair> gpair(info.labels.Size(), GradientPair{}, this->ctx_->gpu_id); linalg::Matrix<GradientPair> gpair(info.labels.Shape(), this->ctx_->gpu_id);
Json config{Object{}}; Json config{Object{}};
this->SaveConfig(&config); this->SaveConfig(&config);

View File

@ -165,9 +165,8 @@ class LambdaRankObj : public FitIntercept {
void CalcLambdaForGroup(std::int32_t iter, common::Span<float const> g_predt, void CalcLambdaForGroup(std::int32_t iter, common::Span<float const> g_predt,
linalg::VectorView<float const> g_label, float w, linalg::VectorView<float const> g_label, float w,
common::Span<std::size_t const> g_rank, bst_group_t g, Delta delta, common::Span<std::size_t const> g_rank, bst_group_t g, Delta delta,
common::Span<GradientPair> g_gpair) { linalg::VectorView<GradientPair> g_gpair) {
std::fill_n(g_gpair.data(), g_gpair.size(), GradientPair{}); std::fill_n(g_gpair.Values().data(), g_gpair.Size(), GradientPair{});
auto p_gpair = g_gpair.data();
auto ti_plus = ti_plus_.HostView(); auto ti_plus = ti_plus_.HostView();
auto tj_minus = tj_minus_.HostView(); auto tj_minus = tj_minus_.HostView();
@ -198,8 +197,8 @@ class LambdaRankObj : public FitIntercept {
std::size_t idx_high = g_rank[rank_high]; std::size_t idx_high = g_rank[rank_high];
std::size_t idx_low = g_rank[rank_low]; std::size_t idx_low = g_rank[rank_low];
p_gpair[idx_high] += pg; g_gpair(idx_high) += pg;
p_gpair[idx_low] += ng; g_gpair(idx_low) += ng;
if (unbiased) { if (unbiased) {
auto k = ti_plus.Size(); auto k = ti_plus.Size();
@ -225,12 +224,13 @@ class LambdaRankObj : public FitIntercept {
MakePairs(ctx_, iter, p_cache_, g, g_label, g_rank, loop); MakePairs(ctx_, iter, p_cache_, g, g_label, g_rank, loop);
if (sum_lambda > 0.0) { if (sum_lambda > 0.0) {
double norm = std::log2(1.0 + sum_lambda) / sum_lambda; double norm = std::log2(1.0 + sum_lambda) / sum_lambda;
std::transform(g_gpair.data(), g_gpair.data() + g_gpair.size(), g_gpair.data(), std::transform(g_gpair.Values().data(), g_gpair.Values().data() + g_gpair.Size(),
[norm](GradientPair const& g) { return g * norm; }); g_gpair.Values().data(), [norm](GradientPair const& g) { return g * norm; });
} }
auto w_norm = p_cache_->WeightNorm(); auto w_norm = p_cache_->WeightNorm();
std::transform(g_gpair.begin(), g_gpair.end(), g_gpair.begin(), std::transform(g_gpair.Values().data(), g_gpair.Values().data() + g_gpair.Size(),
g_gpair.Values().data(),
[&](GradientPair const& gpair) { return gpair * w * w_norm; }); [&](GradientPair const& gpair) { return gpair * w * w_norm; });
} }
@ -301,7 +301,7 @@ class LambdaRankObj : public FitIntercept {
} }
void GetGradient(HostDeviceVector<float> const& predt, MetaInfo const& info, std::int32_t iter, void GetGradient(HostDeviceVector<float> const& predt, MetaInfo const& info, std::int32_t iter,
HostDeviceVector<GradientPair>* out_gpair) override { linalg::Matrix<GradientPair>* out_gpair) override {
CHECK_EQ(info.labels.Size(), predt.Size()) << error::LabelScoreSize(); CHECK_EQ(info.labels.Size(), predt.Size()) << error::LabelScoreSize();
// init/renew cache // init/renew cache
@ -339,7 +339,7 @@ class LambdaRankNDCG : public LambdaRankObj<LambdaRankNDCG, ltr::NDCGCache> {
void CalcLambdaForGroupNDCG(std::int32_t iter, common::Span<float const> g_predt, void CalcLambdaForGroupNDCG(std::int32_t iter, common::Span<float const> g_predt,
linalg::VectorView<float const> g_label, float w, linalg::VectorView<float const> g_label, float w,
common::Span<std::size_t const> g_rank, common::Span<std::size_t const> g_rank,
common::Span<GradientPair> g_gpair, linalg::VectorView<GradientPair> g_gpair,
linalg::VectorView<double const> inv_IDCG, linalg::VectorView<double const> inv_IDCG,
common::Span<double const> discount, bst_group_t g) { common::Span<double const> discount, bst_group_t g) {
auto delta = [&](auto y_high, auto y_low, std::size_t rank_high, std::size_t rank_low, auto delta = [&](auto y_high, auto y_low, std::size_t rank_high, std::size_t rank_low,
@ -351,7 +351,7 @@ class LambdaRankNDCG : public LambdaRankObj<LambdaRankNDCG, ltr::NDCGCache> {
} }
void GetGradientImpl(std::int32_t iter, const HostDeviceVector<float>& predt, void GetGradientImpl(std::int32_t iter, const HostDeviceVector<float>& predt,
const MetaInfo& info, HostDeviceVector<GradientPair>* out_gpair) { const MetaInfo& info, linalg::Matrix<GradientPair>* out_gpair) {
if (ctx_->IsCUDA()) { if (ctx_->IsCUDA()) {
cuda_impl::LambdaRankGetGradientNDCG( cuda_impl::LambdaRankGetGradientNDCG(
ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id), ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id),
@ -363,8 +363,10 @@ class LambdaRankNDCG : public LambdaRankObj<LambdaRankNDCG, ltr::NDCGCache> {
bst_group_t n_groups = p_cache_->Groups(); bst_group_t n_groups = p_cache_->Groups();
auto gptr = p_cache_->DataGroupPtr(ctx_); auto gptr = p_cache_->DataGroupPtr(ctx_);
out_gpair->Resize(info.num_row_); out_gpair->SetDevice(ctx_->Device());
auto h_gpair = out_gpair->HostSpan(); out_gpair->Reshape(info.num_row_, 1);
auto h_gpair = out_gpair->HostView();
auto h_predt = predt.ConstHostSpan(); auto h_predt = predt.ConstHostSpan();
auto h_label = info.labels.HostView(); auto h_label = info.labels.HostView();
auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_); auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_);
@ -378,7 +380,8 @@ class LambdaRankNDCG : public LambdaRankObj<LambdaRankNDCG, ltr::NDCGCache> {
std::size_t cnt = gptr[g + 1] - gptr[g]; std::size_t cnt = gptr[g + 1] - gptr[g];
auto w = h_weight[g]; auto w = h_weight[g];
auto g_predt = h_predt.subspan(gptr[g], cnt); auto g_predt = h_predt.subspan(gptr[g], cnt);
auto g_gpair = h_gpair.subspan(gptr[g], cnt); auto g_gpair =
h_gpair.Slice(linalg::Range(static_cast<std::size_t>(gptr[g]), gptr[g] + cnt), 0);
auto g_label = h_label.Slice(make_range(g), 0); auto g_label = h_label.Slice(make_range(g), 0);
auto g_rank = rank_idx.subspan(gptr[g], cnt); auto g_rank = rank_idx.subspan(gptr[g], cnt);
@ -420,7 +423,7 @@ void LambdaRankGetGradientNDCG(Context const*, std::int32_t, HostDeviceVector<fl
linalg::VectorView<double const>, // input bias ratio linalg::VectorView<double const>, // input bias ratio
linalg::VectorView<double const>, // input bias ratio linalg::VectorView<double const>, // input bias ratio
linalg::VectorView<double>, linalg::VectorView<double>, linalg::VectorView<double>, linalg::VectorView<double>,
HostDeviceVector<GradientPair>*) { linalg::Matrix<GradientPair>*) {
common::AssertGPUSupport(); common::AssertGPUSupport();
} }
@ -470,7 +473,7 @@ void MAPStat(Context const* ctx, linalg::VectorView<float const> label,
class LambdaRankMAP : public LambdaRankObj<LambdaRankMAP, ltr::MAPCache> { class LambdaRankMAP : public LambdaRankObj<LambdaRankMAP, ltr::MAPCache> {
public: public:
void GetGradientImpl(std::int32_t iter, const HostDeviceVector<float>& predt, void GetGradientImpl(std::int32_t iter, const HostDeviceVector<float>& predt,
const MetaInfo& info, HostDeviceVector<GradientPair>* out_gpair) { const MetaInfo& info, linalg::Matrix<GradientPair>* out_gpair) {
CHECK(param_.ndcg_exp_gain) << "NDCG gain can not be set for the MAP objective."; CHECK(param_.ndcg_exp_gain) << "NDCG gain can not be set for the MAP objective.";
if (ctx_->IsCUDA()) { if (ctx_->IsCUDA()) {
return cuda_impl::LambdaRankGetGradientMAP( return cuda_impl::LambdaRankGetGradientMAP(
@ -482,8 +485,11 @@ class LambdaRankMAP : public LambdaRankObj<LambdaRankMAP, ltr::MAPCache> {
auto gptr = p_cache_->DataGroupPtr(ctx_).data(); auto gptr = p_cache_->DataGroupPtr(ctx_).data();
bst_group_t n_groups = p_cache_->Groups(); bst_group_t n_groups = p_cache_->Groups();
out_gpair->Resize(info.num_row_); CHECK_EQ(info.labels.Shape(1), 1) << "Multi-target for learning to rank is not yet supported.";
auto h_gpair = out_gpair->HostSpan(); out_gpair->SetDevice(ctx_->Device());
out_gpair->Reshape(info.num_row_, this->Targets(info));
auto h_gpair = out_gpair->HostView();
auto h_label = info.labels.HostView().Slice(linalg::All(), 0); auto h_label = info.labels.HostView().Slice(linalg::All(), 0);
auto h_predt = predt.ConstHostSpan(); auto h_predt = predt.ConstHostSpan();
auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt); auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt);
@ -514,7 +520,7 @@ class LambdaRankMAP : public LambdaRankObj<LambdaRankMAP, ltr::MAPCache> {
auto cnt = gptr[g + 1] - gptr[g]; auto cnt = gptr[g + 1] - gptr[g];
auto w = h_weight[g]; auto w = h_weight[g];
auto g_predt = h_predt.subspan(gptr[g], cnt); auto g_predt = h_predt.subspan(gptr[g], cnt);
auto g_gpair = h_gpair.subspan(gptr[g], cnt); auto g_gpair = h_gpair.Slice(linalg::Range(gptr[g], gptr[g] + cnt), 0);
auto g_label = h_label.Slice(make_range(g)); auto g_label = h_label.Slice(make_range(g));
auto g_rank = rank_idx.subspan(gptr[g], cnt); auto g_rank = rank_idx.subspan(gptr[g], cnt);
@ -545,7 +551,7 @@ void LambdaRankGetGradientMAP(Context const*, std::int32_t, HostDeviceVector<flo
linalg::VectorView<double const>, // input bias ratio linalg::VectorView<double const>, // input bias ratio
linalg::VectorView<double const>, // input bias ratio linalg::VectorView<double const>, // input bias ratio
linalg::VectorView<double>, linalg::VectorView<double>, linalg::VectorView<double>, linalg::VectorView<double>,
HostDeviceVector<GradientPair>*) { linalg::Matrix<GradientPair>*) {
common::AssertGPUSupport(); common::AssertGPUSupport();
} }
} // namespace cuda_impl } // namespace cuda_impl
@ -557,7 +563,7 @@ void LambdaRankGetGradientMAP(Context const*, std::int32_t, HostDeviceVector<flo
class LambdaRankPairwise : public LambdaRankObj<LambdaRankPairwise, ltr::RankingCache> { class LambdaRankPairwise : public LambdaRankObj<LambdaRankPairwise, ltr::RankingCache> {
public: public:
void GetGradientImpl(std::int32_t iter, const HostDeviceVector<float>& predt, void GetGradientImpl(std::int32_t iter, const HostDeviceVector<float>& predt,
const MetaInfo& info, HostDeviceVector<GradientPair>* out_gpair) { const MetaInfo& info, linalg::Matrix<GradientPair>* out_gpair) {
CHECK(param_.ndcg_exp_gain) << "NDCG gain can not be set for the pairwise objective."; CHECK(param_.ndcg_exp_gain) << "NDCG gain can not be set for the pairwise objective.";
if (ctx_->IsCUDA()) { if (ctx_->IsCUDA()) {
return cuda_impl::LambdaRankGetGradientPairwise( return cuda_impl::LambdaRankGetGradientPairwise(
@ -569,8 +575,10 @@ class LambdaRankPairwise : public LambdaRankObj<LambdaRankPairwise, ltr::Ranking
auto gptr = p_cache_->DataGroupPtr(ctx_); auto gptr = p_cache_->DataGroupPtr(ctx_);
bst_group_t n_groups = p_cache_->Groups(); bst_group_t n_groups = p_cache_->Groups();
out_gpair->Resize(info.num_row_); out_gpair->SetDevice(ctx_->Device());
auto h_gpair = out_gpair->HostSpan(); out_gpair->Reshape(info.num_row_, this->Targets(info));
auto h_gpair = out_gpair->HostView();
auto h_label = info.labels.HostView().Slice(linalg::All(), 0); auto h_label = info.labels.HostView().Slice(linalg::All(), 0);
auto h_predt = predt.ConstHostSpan(); auto h_predt = predt.ConstHostSpan();
auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_); auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_);
@ -585,7 +593,7 @@ class LambdaRankPairwise : public LambdaRankObj<LambdaRankPairwise, ltr::Ranking
auto cnt = gptr[g + 1] - gptr[g]; auto cnt = gptr[g + 1] - gptr[g];
auto w = h_weight[g]; auto w = h_weight[g];
auto g_predt = h_predt.subspan(gptr[g], cnt); auto g_predt = h_predt.subspan(gptr[g], cnt);
auto g_gpair = h_gpair.subspan(gptr[g], cnt); auto g_gpair = h_gpair.Slice(linalg::Range(gptr[g], gptr[g] + cnt), 0);
auto g_label = h_label.Slice(make_range(g)); auto g_label = h_label.Slice(make_range(g));
auto g_rank = rank_idx.subspan(gptr[g], cnt); auto g_rank = rank_idx.subspan(gptr[g], cnt);
@ -611,7 +619,7 @@ void LambdaRankGetGradientPairwise(Context const*, std::int32_t, HostDeviceVecto
linalg::VectorView<double const>, // input bias ratio linalg::VectorView<double const>, // input bias ratio
linalg::VectorView<double const>, // input bias ratio linalg::VectorView<double const>, // input bias ratio
linalg::VectorView<double>, linalg::VectorView<double>, linalg::VectorView<double>, linalg::VectorView<double>,
HostDeviceVector<GradientPair>*) { linalg::Matrix<GradientPair>*) {
common::AssertGPUSupport(); common::AssertGPUSupport();
} }
} // namespace cuda_impl } // namespace cuda_impl
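With the matrix gradient, each query group's lambdas are written through a VectorView slice (h_gpair.Slice(linalg::Range(begin, begin + cnt), 0)) rather than a raw subspan. A rough standalone analogue, using a CSR-style group pointer to visit each group's rows of the single gradient column:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main() {
      // CSR-style group pointer: samples [0, 3) belong to query group 0, [3, 7) to group 1.
      std::vector<std::size_t> gptr = {0, 3, 7};
      std::vector<float> gpair(7, 0.0f);  // single gradient column, one row per sample

      for (std::size_t g = 0; g + 1 < gptr.size(); ++g) {
        std::size_t begin = gptr[g], cnt = gptr[g + 1] - gptr[g];
        // Rough analogue of h_gpair.Slice(linalg::Range(begin, begin + cnt), 0):
        // only this group's rows of the gradient column are visible to the group kernel.
        for (std::size_t i = 0; i < cnt; ++i) {
          gpair[begin + i] += 1.0f;  // per-pair lambda contributions would accumulate here
        }
      }
      std::cout << gpair[0] << " " << gpair[6] << "\n";  // 1 1
      return 0;
    }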

View File

@ -93,7 +93,7 @@ struct GetGradOp {
// obtain group segment data. // obtain group segment data.
auto g_label = args.labels.Slice(linalg::Range(data_group_begin, data_group_begin + n_data), 0); auto g_label = args.labels.Slice(linalg::Range(data_group_begin, data_group_begin + n_data), 0);
auto g_predt = args.predts.subspan(data_group_begin, n_data); auto g_predt = args.predts.subspan(data_group_begin, n_data);
auto g_gpair = args.gpairs.subspan(data_group_begin, n_data).data(); auto g_gpair = args.gpairs.Slice(linalg::Range(data_group_begin, data_group_begin + n_data));
auto g_rank = args.d_sorted_idx.subspan(data_group_begin, n_data); auto g_rank = args.d_sorted_idx.subspan(data_group_begin, n_data);
auto [i, j] = make_pair(idx, g); auto [i, j] = make_pair(idx, g);
@ -128,8 +128,8 @@ struct GetGradOp {
auto ngt = GradientPair{common::TruncateWithRounding(gr.GetGrad(), ng.GetGrad()), auto ngt = GradientPair{common::TruncateWithRounding(gr.GetGrad(), ng.GetGrad()),
common::TruncateWithRounding(gr.GetHess(), ng.GetHess())}; common::TruncateWithRounding(gr.GetHess(), ng.GetHess())};
dh::AtomicAddGpair(g_gpair + idx_high, pgt); dh::AtomicAddGpair(&g_gpair(idx_high), pgt);
dh::AtomicAddGpair(g_gpair + idx_low, ngt); dh::AtomicAddGpair(&g_gpair(idx_low), ngt);
} }
if (unbiased && need_update) { if (unbiased && need_update) {
@ -266,16 +266,16 @@ void CalcGrad(Context const* ctx, MetaInfo const& info, std::shared_ptr<ltr::Ran
*/ */
auto d_weights = common::MakeOptionalWeights(ctx, info.weights_); auto d_weights = common::MakeOptionalWeights(ctx, info.weights_);
auto w_norm = p_cache->WeightNorm(); auto w_norm = p_cache->WeightNorm();
thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), d_gpair.size(), thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), d_gpair.Size(),
[=] XGBOOST_DEVICE(std::size_t i) { [=] XGBOOST_DEVICE(std::size_t i) mutable {
auto g = dh::SegmentId(d_gptr, i); auto g = dh::SegmentId(d_gptr, i);
auto sum_lambda = thrust::get<2>(d_max_lambdas[g]); auto sum_lambda = thrust::get<2>(d_max_lambdas[g]);
// Normalization // Normalization
if (sum_lambda > 0.0) { if (sum_lambda > 0.0) {
double norm = std::log2(1.0 + sum_lambda) / sum_lambda; double norm = std::log2(1.0 + sum_lambda) / sum_lambda;
d_gpair[i] *= norm; d_gpair(i, 0) *= norm;
} }
d_gpair[i] *= (d_weights[g] * w_norm); d_gpair(i, 0) *= (d_weights[g] * w_norm);
}); });
} }
@ -288,7 +288,7 @@ void Launch(Context const* ctx, std::int32_t iter, HostDeviceVector<float> const
linalg::VectorView<double const> ti_plus, // input bias ratio linalg::VectorView<double const> ti_plus, // input bias ratio
linalg::VectorView<double const> tj_minus, // input bias ratio linalg::VectorView<double const> tj_minus, // input bias ratio
linalg::VectorView<double> li, linalg::VectorView<double> lj, linalg::VectorView<double> li, linalg::VectorView<double> lj,
HostDeviceVector<GradientPair>* out_gpair) { linalg::Matrix<GradientPair>* out_gpair) {
// boilerplate // boilerplate
std::int32_t device_id = ctx->gpu_id; std::int32_t device_id = ctx->gpu_id;
dh::safe_cuda(cudaSetDevice(device_id)); dh::safe_cuda(cudaSetDevice(device_id));
@ -296,8 +296,8 @@ void Launch(Context const* ctx, std::int32_t iter, HostDeviceVector<float> const
info.labels.SetDevice(device_id); info.labels.SetDevice(device_id);
preds.SetDevice(device_id); preds.SetDevice(device_id);
out_gpair->SetDevice(device_id); out_gpair->SetDevice(ctx->Device());
out_gpair->Resize(preds.Size()); out_gpair->Reshape(preds.Size(), 1);
CHECK(p_cache); CHECK(p_cache);
@ -308,8 +308,9 @@ void Launch(Context const* ctx, std::int32_t iter, HostDeviceVector<float> const
auto label = info.labels.View(ctx->gpu_id); auto label = info.labels.View(ctx->gpu_id);
auto predts = preds.ConstDeviceSpan(); auto predts = preds.ConstDeviceSpan();
auto gpairs = out_gpair->DeviceSpan(); auto gpairs = out_gpair->View(ctx->Device());
thrust::fill_n(ctx->CUDACtx()->CTP(), gpairs.data(), gpairs.size(), GradientPair{0.0f, 0.0f}); thrust::fill_n(ctx->CUDACtx()->CTP(), gpairs.Values().data(), gpairs.Size(),
GradientPair{0.0f, 0.0f});
auto const d_threads_group_ptr = p_cache->CUDAThreadsGroupPtr(); auto const d_threads_group_ptr = p_cache->CUDAThreadsGroupPtr();
auto const d_gptr = p_cache->DataGroupPtr(ctx); auto const d_gptr = p_cache->DataGroupPtr(ctx);
@ -371,7 +372,7 @@ void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter,
linalg::VectorView<double const> ti_plus, // input bias ratio linalg::VectorView<double const> ti_plus, // input bias ratio
linalg::VectorView<double const> tj_minus, // input bias ratio linalg::VectorView<double const> tj_minus, // input bias ratio
linalg::VectorView<double> li, linalg::VectorView<double> lj, linalg::VectorView<double> li, linalg::VectorView<double> lj,
HostDeviceVector<GradientPair>* out_gpair) { linalg::Matrix<GradientPair>* out_gpair) {
// boilerplate // boilerplate
std::int32_t device_id = ctx->gpu_id; std::int32_t device_id = ctx->gpu_id;
dh::safe_cuda(cudaSetDevice(device_id)); dh::safe_cuda(cudaSetDevice(device_id));
@ -440,7 +441,7 @@ void LambdaRankGetGradientMAP(Context const* ctx, std::int32_t iter,
linalg::VectorView<double const> ti_plus, // input bias ratio linalg::VectorView<double const> ti_plus, // input bias ratio
linalg::VectorView<double const> tj_minus, // input bias ratio linalg::VectorView<double const> tj_minus, // input bias ratio
linalg::VectorView<double> li, linalg::VectorView<double> lj, linalg::VectorView<double> li, linalg::VectorView<double> lj,
HostDeviceVector<GradientPair>* out_gpair) { linalg::Matrix<GradientPair>* out_gpair) {
std::int32_t device_id = ctx->gpu_id; std::int32_t device_id = ctx->gpu_id;
dh::safe_cuda(cudaSetDevice(device_id)); dh::safe_cuda(cudaSetDevice(device_id));
@ -479,7 +480,7 @@ void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter,
linalg::VectorView<double const> ti_plus, // input bias ratio linalg::VectorView<double const> ti_plus, // input bias ratio
linalg::VectorView<double const> tj_minus, // input bias ratio linalg::VectorView<double const> tj_minus, // input bias ratio
linalg::VectorView<double> li, linalg::VectorView<double> lj, linalg::VectorView<double> li, linalg::VectorView<double> lj,
HostDeviceVector<GradientPair>* out_gpair) { linalg::Matrix<GradientPair>* out_gpair) {
std::int32_t device_id = ctx->gpu_id; std::int32_t device_id = ctx->gpu_id;
dh::safe_cuda(cudaSetDevice(device_id)); dh::safe_cuda(cudaSetDevice(device_id));

View File

@ -61,7 +61,7 @@ struct KernelInputs {
linalg::MatrixView<float const> labels; linalg::MatrixView<float const> labels;
common::Span<float const> predts; common::Span<float const> predts;
common::Span<GradientPair> gpairs; linalg::MatrixView<GradientPair> gpairs;
linalg::VectorView<GradientPair const> d_roundings; linalg::VectorView<GradientPair const> d_roundings;
double const *d_cost_rounding; double const *d_cost_rounding;
@ -79,8 +79,8 @@ struct MakePairsOp {
/** /**
* \brief Make pair for the topk pair method. * \brief Make pair for the topk pair method.
*/ */
XGBOOST_DEVICE std::tuple<std::size_t, std::size_t> WithTruncation(std::size_t idx, [[nodiscard]] XGBOOST_DEVICE std::tuple<std::size_t, std::size_t> WithTruncation(
bst_group_t g) const { std::size_t idx, bst_group_t g) const {
auto thread_group_begin = args.d_threads_group_ptr[g]; auto thread_group_begin = args.d_threads_group_ptr[g];
auto idx_in_thread_group = idx - thread_group_begin; auto idx_in_thread_group = idx - thread_group_begin;

View File

@ -154,7 +154,7 @@ void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter,
linalg::VectorView<double const> t_plus, // input bias ratio linalg::VectorView<double const> t_plus, // input bias ratio
linalg::VectorView<double const> t_minus, // input bias ratio linalg::VectorView<double const> t_minus, // input bias ratio
linalg::VectorView<double> li, linalg::VectorView<double> lj, linalg::VectorView<double> li, linalg::VectorView<double> lj,
HostDeviceVector<GradientPair>* out_gpair); linalg::Matrix<GradientPair>* out_gpair);
/** /**
* \brief Generate statistic for MAP used for calculating \Delta Z in lambda mart. * \brief Generate statistic for MAP used for calculating \Delta Z in lambda mart.
@ -168,7 +168,7 @@ void LambdaRankGetGradientMAP(Context const* ctx, std::int32_t iter,
linalg::VectorView<double const> t_plus, // input bias ratio linalg::VectorView<double const> t_plus, // input bias ratio
linalg::VectorView<double const> t_minus, // input bias ratio linalg::VectorView<double const> t_minus, // input bias ratio
linalg::VectorView<double> li, linalg::VectorView<double> lj, linalg::VectorView<double> li, linalg::VectorView<double> lj,
HostDeviceVector<GradientPair>* out_gpair); linalg::Matrix<GradientPair>* out_gpair);
void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter, void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter,
HostDeviceVector<float> const& predt, const MetaInfo& info, HostDeviceVector<float> const& predt, const MetaInfo& info,
@ -176,7 +176,7 @@ void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter,
linalg::VectorView<double const> ti_plus, // input bias ratio linalg::VectorView<double const> ti_plus, // input bias ratio
linalg::VectorView<double const> tj_minus, // input bias ratio linalg::VectorView<double const> tj_minus, // input bias ratio
linalg::VectorView<double> li, linalg::VectorView<double> lj, linalg::VectorView<double> li, linalg::VectorView<double> lj,
HostDeviceVector<GradientPair>* out_gpair); linalg::Matrix<GradientPair>* out_gpair);
void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView<double const> li_full, void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView<double const> li_full,
linalg::VectorView<double const> lj_full, linalg::VectorView<double const> lj_full,

View File

@ -1,5 +1,5 @@
/*! /**
* Copyright 2015-2022 by XGBoost Contributors * Copyright 2015-2023, XGBoost Contributors
* \file multi_class.cc * \file multi_class.cc
* \brief Definition of multi-class classification objectives. * \brief Definition of multi-class classification objectives.
* \author Tianqi Chen * \author Tianqi Chen
@ -48,13 +48,8 @@ class SoftmaxMultiClassObj : public ObjFunction {
ObjInfo Task() const override { return ObjInfo::kClassification; } ObjInfo Task() const override { return ObjInfo::kClassification; }
void GetGradient(const HostDeviceVector<bst_float>& preds, void GetGradient(const HostDeviceVector<bst_float>& preds, const MetaInfo& info, std::int32_t,
const MetaInfo& info, linalg::Matrix<GradientPair>* out_gpair) override {
int iter,
HostDeviceVector<GradientPair>* out_gpair) override {
// Remove unused parameter compiler warning.
(void) iter;
if (info.labels.Size() == 0) { if (info.labels.Size() == 0) {
return; return;
} }
@ -77,7 +72,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
label_correct_.Resize(1); label_correct_.Resize(1);
label_correct_.SetDevice(device); label_correct_.SetDevice(device);
out_gpair->Resize(preds.Size()); out_gpair->Reshape(info.num_row_, static_cast<std::uint64_t>(nclass));
label_correct_.Fill(1); label_correct_.Fill(1);
const bool is_null_weight = info.weights_.Size() == 0; const bool is_null_weight = info.weights_.Size() == 0;
@ -115,7 +110,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
gpair[idx * nclass + k] = GradientPair(p * wt, h); gpair[idx * nclass + k] = GradientPair(p * wt, h);
} }
}, common::Range{0, ndata}, ctx_->Threads(), device) }, common::Range{0, ndata}, ctx_->Threads(), device)
.Eval(out_gpair, info.labels.Data(), &preds, &info.weights_, &label_correct_); .Eval(out_gpair->Data(), info.labels.Data(), &preds, &info.weights_, &label_correct_);
std::vector<int>& label_correct_h = label_correct_.HostVector(); std::vector<int>& label_correct_h = label_correct_.HostVector();
for (auto const flag : label_correct_h) { for (auto const flag : label_correct_h) {
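The softmax objective now reshapes the gradient to (num_row_, nclass), and the existing idx * nclass + k indexing matches that row-major layout. A standalone CPU sketch of softmax gradients written into such a buffer; the 2 * p * (1 - p) hessian follows the usual softmax-objective form, and weights and the label-correct check are omitted to keep the example short:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main() {
      // Gradients laid out row-major as (n_samples, n_classes), matching gpair[idx * nclass + k].
      std::size_t const n = 2, nclass = 3;
      std::vector<float> margin = {0.5f, 1.0f, -0.5f, 2.0f, 0.0f, 0.0f};  // (n, nclass) raw scores
      std::vector<std::size_t> label = {1, 0};
      std::vector<float> grad(n * nclass), hess(n * nclass);

      for (std::size_t i = 0; i < n; ++i) {
        // Softmax over this sample's margins (shifted by the max for numerical stability).
        float wmax = margin[i * nclass];
        for (std::size_t k = 1; k < nclass; ++k) wmax = std::max(wmax, margin[i * nclass + k]);
        std::vector<float> p(nclass);
        float wsum = 0.0f;
        for (std::size_t k = 0; k < nclass; ++k) {
          p[k] = std::exp(margin[i * nclass + k] - wmax);
          wsum += p[k];
        }
        for (std::size_t k = 0; k < nclass; ++k) {
          p[k] /= wsum;
          grad[i * nclass + k] = (k == label[i]) ? p[k] - 1.0f : p[k];
          hess[i * nclass + k] = 2.0f * p[k] * (1.0f - p[k]);
        }
      }
      std::cout << grad[0 * nclass + 1] << "\n";  // sample 0, true class: roughly -0.45 here
      return 0;
    }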

View File

@ -27,13 +27,12 @@
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA)
namespace xgboost { namespace xgboost::obj {
namespace obj {
class QuantileRegression : public ObjFunction { class QuantileRegression : public ObjFunction {
common::QuantileLossParam param_; common::QuantileLossParam param_;
HostDeviceVector<float> alpha_; HostDeviceVector<float> alpha_;
bst_target_t Targets(MetaInfo const& info) const override { [[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override {
auto const& alpha = param_.quantile_alpha.Get(); auto const& alpha = param_.quantile_alpha.Get();
CHECK_EQ(alpha.size(), alpha_.Size()) << "The objective is not yet configured."; CHECK_EQ(alpha.size(), alpha_.Size()) << "The objective is not yet configured.";
if (info.ShouldHaveLabels()) { if (info.ShouldHaveLabels()) {
@ -50,7 +49,7 @@ class QuantileRegression : public ObjFunction {
public: public:
void GetGradient(HostDeviceVector<float> const& preds, const MetaInfo& info, std::int32_t iter, void GetGradient(HostDeviceVector<float> const& preds, const MetaInfo& info, std::int32_t iter,
HostDeviceVector<GradientPair>* out_gpair) override { linalg::Matrix<GradientPair>* out_gpair) override {
if (iter == 0) { if (iter == 0) {
CheckInitInputs(info); CheckInitInputs(info);
} }
@ -65,10 +64,11 @@ class QuantileRegression : public ObjFunction {
auto labels = info.labels.View(ctx_->gpu_id); auto labels = info.labels.View(ctx_->gpu_id);
out_gpair->SetDevice(ctx_->gpu_id); out_gpair->SetDevice(ctx_->Device());
out_gpair->Resize(n_targets * info.num_row_); CHECK_EQ(info.labels.Shape(1), 1)
auto gpair = << "Multi-target for quantile regression is not yet supported.";
linalg::MakeTensorView(ctx_, out_gpair, info.num_row_, n_alphas, n_targets / n_alphas); out_gpair->Reshape(info.num_row_, n_targets);
auto gpair = out_gpair->View(ctx_->Device());
info.weights_.SetDevice(ctx_->gpu_id); info.weights_.SetDevice(ctx_->gpu_id);
common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan() common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan()
@ -85,15 +85,16 @@ class QuantileRegression : public ObjFunction {
ctx_, gpair, [=] XGBOOST_DEVICE(std::size_t i, GradientPair const&) mutable { ctx_, gpair, [=] XGBOOST_DEVICE(std::size_t i, GradientPair const&) mutable {
auto [sample_id, quantile_id, target_id] = auto [sample_id, quantile_id, target_id] =
linalg::UnravelIndex(i, n_samples, alpha.size(), n_targets / alpha.size()); linalg::UnravelIndex(i, n_samples, alpha.size(), n_targets / alpha.size());
assert(target_id == 0);
auto d = predt(i) - labels(sample_id, target_id); auto d = predt(i) - labels(sample_id, target_id);
auto h = weight[sample_id]; auto h = weight[sample_id];
if (d >= 0) { if (d >= 0) {
auto g = (1.0f - alpha[quantile_id]) * weight[sample_id]; auto g = (1.0f - alpha[quantile_id]) * weight[sample_id];
gpair(sample_id, quantile_id, target_id) = GradientPair{g, h}; gpair(sample_id, quantile_id) = GradientPair{g, h};
} else { } else {
auto g = (-alpha[quantile_id] * weight[sample_id]); auto g = (-alpha[quantile_id] * weight[sample_id]);
gpair(sample_id, quantile_id, target_id) = GradientPair{g, h}; gpair(sample_id, quantile_id) = GradientPair{g, h};
} }
}); });
} }
@ -192,7 +193,7 @@ class QuantileRegression : public ObjFunction {
param_.Validate(); param_.Validate();
this->alpha_.HostVector() = param_.quantile_alpha.Get(); this->alpha_.HostVector() = param_.quantile_alpha.Get();
} }
ObjInfo Task() const override { return {ObjInfo::kRegression, true, true}; } [[nodiscard]] ObjInfo Task() const override { return {ObjInfo::kRegression, true, true}; }
static char const* Name() { return "reg:quantileerror"; } static char const* Name() { return "reg:quantileerror"; }
void SaveConfig(Json* p_out) const override { void SaveConfig(Json* p_out) const override {
@ -206,8 +207,8 @@ class QuantileRegression : public ObjFunction {
alpha_.HostVector() = param_.quantile_alpha.Get(); alpha_.HostVector() = param_.quantile_alpha.Get();
} }
const char* DefaultEvalMetric() const override { return "quantile"; } [[nodiscard]] const char* DefaultEvalMetric() const override { return "quantile"; }
Json DefaultMetricConfig() const override { [[nodiscard]] Json DefaultMetricConfig() const override {
CHECK(param_.GetInitialised()); CHECK(param_.GetInitialised());
Json config{Object{}}; Json config{Object{}};
config["name"] = String{this->DefaultEvalMetric()}; config["name"] = String{this->DefaultEvalMetric()};
@ -223,5 +224,4 @@ XGBOOST_REGISTER_OBJECTIVE(QuantileRegression, QuantileRegression::Name())
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
DMLC_REGISTRY_FILE_TAG(quantile_obj_gpu); DMLC_REGISTRY_FILE_TAG(quantile_obj_gpu);
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA)
} // namespace obj } // namespace xgboost::obj
} // namespace xgboost



@ -36,12 +36,12 @@
#include "xgboost/tree_model.h" // RegTree #include "xgboost/tree_model.h" // RegTree
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
#include "../common/cuda_context.cuh" // for CUDAContext
#include "../common/device_helpers.cuh" #include "../common/device_helpers.cuh"
#include "../common/linalg_op.cuh" #include "../common/linalg_op.cuh"
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA)
namespace xgboost { namespace xgboost::obj {
namespace obj {
namespace { namespace {
void CheckRegInputs(MetaInfo const& info, HostDeviceVector<bst_float> const& preds) { void CheckRegInputs(MetaInfo const& info, HostDeviceVector<bst_float> const& preds) {
CheckInitInputs(info); CheckInitInputs(info);
@ -68,33 +68,60 @@ class RegLossObj : public FitIntercept {
HostDeviceVector<float> additional_input_; HostDeviceVector<float> additional_input_;
public: public:
// 0 - label_correct flag, 1 - scale_pos_weight, 2 - is_null_weight void ValidateLabel(MetaInfo const& info) {
RegLossObj(): additional_input_(3) {} auto label = info.labels.View(ctx_->Ordinal());
auto valid = ctx_->DispatchDevice(
[&] {
return std::all_of(linalg::cbegin(label), linalg::cend(label),
[](float y) -> bool { return Loss::CheckLabel(y); });
},
[&] {
#if defined(XGBOOST_USE_CUDA)
auto cuctx = ctx_->CUDACtx();
auto it = dh::MakeTransformIterator<bool>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) -> bool {
auto [m, n] = linalg::UnravelIndex(i, label.Shape());
return Loss::CheckLabel(label(m, n));
});
return dh::Reduce(cuctx->CTP(), it, it + label.Size(), true, thrust::logical_and<>{});
#else
common::AssertGPUSupport();
return false;
#endif // defined(XGBOOST_USE_CUDA)
});
if (!valid) {
LOG(FATAL) << Loss::LabelErrorMsg();
}
}
// 0 - scale_pos_weight, 1 - is_null_weight
RegLossObj(): additional_input_(2) {}
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override { void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.UpdateAllowUnknown(args); param_.UpdateAllowUnknown(args);
} }
ObjInfo Task() const override { return Loss::Info(); } [[nodiscard]] ObjInfo Task() const override { return Loss::Info(); }
bst_target_t Targets(MetaInfo const& info) const override { [[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override {
// Multi-target regression. // Multi-target regression.
return std::max(static_cast<size_t>(1), info.labels.Shape(1)); return std::max(static_cast<std::size_t>(1), info.labels.Shape(1));
} }
void GetGradient(const HostDeviceVector<bst_float>& preds, void GetGradient(const HostDeviceVector<bst_float>& preds, const MetaInfo& info,
const MetaInfo &info, int, std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) override {
HostDeviceVector<GradientPair>* out_gpair) override {
CheckRegInputs(info, preds); CheckRegInputs(info, preds);
if (iter == 0) {
ValidateLabel(info);
}
size_t const ndata = preds.Size(); size_t const ndata = preds.Size();
out_gpair->Resize(ndata); out_gpair->SetDevice(ctx_->Device());
auto device = ctx_->gpu_id; auto device = ctx_->gpu_id;
additional_input_.HostVector().begin()[0] = 1; // Fill the label_correct flag
bool is_null_weight = info.weights_.Size() == 0; bool is_null_weight = info.weights_.Size() == 0;
auto scale_pos_weight = param_.scale_pos_weight; auto scale_pos_weight = param_.scale_pos_weight;
additional_input_.HostVector().begin()[1] = scale_pos_weight; additional_input_.HostVector().begin()[0] = scale_pos_weight;
additional_input_.HostVector().begin()[2] = is_null_weight; additional_input_.HostVector().begin()[1] = is_null_weight;
const size_t nthreads = ctx_->Threads(); const size_t nthreads = ctx_->Threads();
bool on_device = device >= 0; bool on_device = device >= 0;
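
For orientation, a minimal sketch of the Context::DispatchDevice pattern that the new ValidateLabel above relies on: the first callable runs on the CPU, the second on CUDA, and the common return value carries the reduction result. Only the calls visible in this hunk (DispatchDevice, linalg::cbegin/cend) are assumed; the predicate and the function name are illustrative.

#include <algorithm>  // std::all_of
#include <cmath>      // std::isfinite

// Hypothetical CPU-only sketch; a real GPU branch would reduce with thrust as in the hunk above.
template <typename LabelView>
bool AllLabelsFiniteSketch(xgboost::Context const* ctx, LabelView label) {
  return ctx->DispatchDevice(
      [&] {  // CPU branch: iterate the host view with the linalg iterators
        return std::all_of(xgboost::linalg::cbegin(label), xgboost::linalg::cend(label),
                           [](float y) { return std::isfinite(y); });
      },
      [&] {  // CUDA branch elided in this sketch
        return true;
      });
}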
@ -102,7 +129,8 @@ class RegLossObj : public FitIntercept {
// for better performance. // for better performance.
const size_t n_data_blocks = std::max(static_cast<size_t>(1), (on_device ? ndata : nthreads)); const size_t n_data_blocks = std::max(static_cast<size_t>(1), (on_device ? ndata : nthreads));
const size_t block_size = ndata / n_data_blocks + !!(ndata % n_data_blocks); const size_t block_size = ndata / n_data_blocks + !!(ndata % n_data_blocks);
auto const n_targets = std::max(info.labels.Shape(1), static_cast<size_t>(1)); auto const n_targets = this->Targets(info);
out_gpair->Reshape(info.num_row_, n_targets);
common::Transform<>::Init( common::Transform<>::Init(
[block_size, ndata, n_targets] XGBOOST_DEVICE( [block_size, ndata, n_targets] XGBOOST_DEVICE(
@ -117,8 +145,8 @@ class RegLossObj : public FitIntercept {
GradientPair* out_gpair_ptr = _out_gpair.data(); GradientPair* out_gpair_ptr = _out_gpair.data();
const size_t begin = data_block_idx*block_size; const size_t begin = data_block_idx*block_size;
const size_t end = std::min(ndata, begin + block_size); const size_t end = std::min(ndata, begin + block_size);
const float _scale_pos_weight = _additional_input[1]; const float _scale_pos_weight = _additional_input[0];
const bool _is_null_weight = _additional_input[2]; const bool _is_null_weight = _additional_input[1];
for (size_t idx = begin; idx < end; ++idx) { for (size_t idx = begin; idx < end; ++idx) {
bst_float p = Loss::PredTransform(preds_ptr[idx]); bst_float p = Loss::PredTransform(preds_ptr[idx]);
@ -127,26 +155,17 @@ class RegLossObj : public FitIntercept {
if (label == 1.0f) { if (label == 1.0f) {
w *= _scale_pos_weight; w *= _scale_pos_weight;
} }
if (!Loss::CheckLabel(label)) {
// If there is an incorrect label, the host code will know.
_additional_input[0] = 0;
}
out_gpair_ptr[idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w, out_gpair_ptr[idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w,
Loss::SecondOrderGradient(p, label) * w); Loss::SecondOrderGradient(p, label) * w);
} }
}, },
common::Range{0, static_cast<int64_t>(n_data_blocks)}, nthreads, device) common::Range{0, static_cast<int64_t>(n_data_blocks)}, nthreads, device)
.Eval(&additional_input_, out_gpair, &preds, info.labels.Data(), .Eval(&additional_input_, out_gpair->Data(), &preds, info.labels.Data(),
&info.weights_); &info.weights_);
auto const flag = additional_input_.HostVector().begin()[0];
if (flag == 0) {
LOG(FATAL) << Loss::LabelErrorMsg();
}
} }
public: public:
const char* DefaultEvalMetric() const override { [[nodiscard]] const char* DefaultEvalMetric() const override {
return Loss::DefaultEvalMetric(); return Loss::DefaultEvalMetric();
} }
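
Because the signature change above is what custom C++ objectives now implement, here is a minimal single-target sketch of a GetGradient body under the new API. It assumes only members shown in this diff (SetDevice, Reshape, HostView, MetaInfo::labels, num_row_); the free-function name and the squared-error math are illustrative.

// Hypothetical free function mirroring the body of an ObjFunction::GetGradient override.
void SquaredErrorGradientSketch(xgboost::Context const* ctx,
                                xgboost::HostDeviceVector<float> const& preds,
                                xgboost::MetaInfo const& info,
                                xgboost::linalg::Matrix<xgboost::GradientPair>* out_gpair) {
  out_gpair->SetDevice(ctx->Device());
  out_gpair->Reshape(info.num_row_, 1);  // gradients must now be shaped (n_samples, n_targets)
  auto gpair = out_gpair->HostView();
  auto labels = info.labels.HostView();
  auto const& h_preds = preds.ConstHostVector();
  for (std::size_t i = 0; i < info.num_row_; ++i) {
    float d = h_preds[i] - labels(i, 0);
    gpair(i, 0) = xgboost::GradientPair{d, 1.0f};  // (gradient, hessian) for sample i
  }
}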
@ -160,7 +179,7 @@ class RegLossObj : public FitIntercept {
.Eval(io_preds); .Eval(io_preds);
} }
float ProbToMargin(float base_score) const override { [[nodiscard]] float ProbToMargin(float base_score) const override {
return Loss::ProbToMargin(base_score); return Loss::ProbToMargin(base_score);
} }
@ -215,21 +234,21 @@ class PseudoHuberRegression : public FitIntercept {
public: public:
void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); }
ObjInfo Task() const override { return ObjInfo::kRegression; } [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; }
bst_target_t Targets(MetaInfo const& info) const override { [[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override {
return std::max(static_cast<size_t>(1), info.labels.Shape(1)); return std::max(static_cast<std::size_t>(1), info.labels.Shape(1));
} }
void GetGradient(HostDeviceVector<bst_float> const& preds, const MetaInfo& info, int /*iter*/, void GetGradient(HostDeviceVector<bst_float> const& preds, const MetaInfo& info, int /*iter*/,
HostDeviceVector<GradientPair>* out_gpair) override { linalg::Matrix<GradientPair>* out_gpair) override {
CheckRegInputs(info, preds); CheckRegInputs(info, preds);
auto slope = param_.huber_slope; auto slope = param_.huber_slope;
CHECK_NE(slope, 0.0) << "slope for pseudo huber cannot be 0."; CHECK_NE(slope, 0.0) << "slope for pseudo huber cannot be 0.";
auto labels = info.labels.View(ctx_->gpu_id); auto labels = info.labels.View(ctx_->gpu_id);
out_gpair->SetDevice(ctx_->gpu_id); out_gpair->SetDevice(ctx_->gpu_id);
out_gpair->Resize(info.labels.Size()); out_gpair->Reshape(info.num_row_, this->Targets(info));
auto gpair = linalg::MakeVec(out_gpair); auto gpair = out_gpair->View(ctx_->Device());
preds.SetDevice(ctx_->gpu_id); preds.SetDevice(ctx_->gpu_id);
auto predt = linalg::MakeVec(&preds); auto predt = linalg::MakeVec(&preds);
@ -252,7 +271,7 @@ class PseudoHuberRegression : public FitIntercept {
}); });
} }
const char* DefaultEvalMetric() const override { return "mphe"; } [[nodiscard]] const char* DefaultEvalMetric() const override { return "mphe"; }
void SaveConfig(Json* p_out) const override { void SaveConfig(Json* p_out) const override {
auto& out = *p_out; auto& out = *p_out;
@ -292,15 +311,15 @@ class PoissonRegression : public FitIntercept {
param_.UpdateAllowUnknown(args); param_.UpdateAllowUnknown(args);
} }
ObjInfo Task() const override { return ObjInfo::kRegression; } [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; }
void GetGradient(const HostDeviceVector<bst_float>& preds, void GetGradient(const HostDeviceVector<bst_float>& preds, const MetaInfo& info, int,
const MetaInfo &info, int, linalg::Matrix<GradientPair>* out_gpair) override {
HostDeviceVector<GradientPair> *out_gpair) override {
CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty"; CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided"; CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided";
size_t const ndata = preds.Size(); size_t const ndata = preds.Size();
out_gpair->Resize(ndata); out_gpair->SetDevice(ctx_->Device());
out_gpair->Reshape(info.num_row_, this->Targets(info));
auto device = ctx_->gpu_id; auto device = ctx_->gpu_id;
label_correct_.Resize(1); label_correct_.Resize(1);
label_correct_.Fill(1); label_correct_.Fill(1);
@ -328,7 +347,7 @@ class PoissonRegression : public FitIntercept {
expf(p + max_delta_step) * w}; expf(p + max_delta_step) * w};
}, },
common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(), device).Eval( common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(), device).Eval(
&label_correct_, out_gpair, &preds, info.labels.Data(), &info.weights_); &label_correct_, out_gpair->Data(), &preds, info.labels.Data(), &info.weights_);
// copy "label correct" flags back to host // copy "label correct" flags back to host
std::vector<int>& label_correct_h = label_correct_.HostVector(); std::vector<int>& label_correct_h = label_correct_.HostVector();
for (auto const flag : label_correct_h) { for (auto const flag : label_correct_h) {
@ -349,10 +368,10 @@ class PoissonRegression : public FitIntercept {
void EvalTransform(HostDeviceVector<bst_float> *io_preds) override { void EvalTransform(HostDeviceVector<bst_float> *io_preds) override {
PredTransform(io_preds); PredTransform(io_preds);
} }
bst_float ProbToMargin(bst_float base_score) const override { [[nodiscard]] float ProbToMargin(bst_float base_score) const override {
return std::log(base_score); return std::log(base_score);
} }
const char* DefaultEvalMetric() const override { [[nodiscard]] const char* DefaultEvalMetric() const override {
return "poisson-nloglik"; return "poisson-nloglik";
} }
@ -383,16 +402,15 @@ XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson")
class CoxRegression : public FitIntercept { class CoxRegression : public FitIntercept {
public: public:
void Configure(Args const&) override {} void Configure(Args const&) override {}
ObjInfo Task() const override { return ObjInfo::kRegression; } [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; }
void GetGradient(const HostDeviceVector<bst_float>& preds, void GetGradient(const HostDeviceVector<bst_float>& preds, const MetaInfo& info, int,
const MetaInfo &info, int, linalg::Matrix<GradientPair>* out_gpair) override {
HostDeviceVector<GradientPair> *out_gpair) override {
CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty"; CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided"; CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided";
const auto& preds_h = preds.HostVector(); const auto& preds_h = preds.HostVector();
out_gpair->Resize(preds_h.size()); out_gpair->Reshape(info.num_row_, this->Targets(info));
auto& gpair = out_gpair->HostVector(); auto gpair = out_gpair->HostView();
const std::vector<size_t> &label_order = info.LabelAbsSort(ctx_); const std::vector<size_t> &label_order = info.LabelAbsSort(ctx_);
const omp_ulong ndata = static_cast<omp_ulong>(preds_h.size()); // NOLINT(*) const omp_ulong ndata = static_cast<omp_ulong>(preds_h.size()); // NOLINT(*)
@ -441,7 +459,7 @@ class CoxRegression : public FitIntercept {
const double grad = exp_p*r_k - static_cast<bst_float>(y > 0); const double grad = exp_p*r_k - static_cast<bst_float>(y > 0);
const double hess = exp_p * r_k - exp_p * exp_p * s_k; const double hess = exp_p * r_k - exp_p * exp_p * s_k;
gpair.at(ind) = GradientPair(grad * w, hess * w); gpair(ind) = GradientPair(grad * w, hess * w);
last_abs_y = abs_y; last_abs_y = abs_y;
last_exp_p = exp_p; last_exp_p = exp_p;
@ -457,10 +475,10 @@ class CoxRegression : public FitIntercept {
void EvalTransform(HostDeviceVector<bst_float> *io_preds) override { void EvalTransform(HostDeviceVector<bst_float> *io_preds) override {
PredTransform(io_preds); PredTransform(io_preds);
} }
bst_float ProbToMargin(bst_float base_score) const override { [[nodiscard]] float ProbToMargin(bst_float base_score) const override {
return std::log(base_score); return std::log(base_score);
} }
const char* DefaultEvalMetric() const override { [[nodiscard]] const char* DefaultEvalMetric() const override {
return "cox-nloglik"; return "cox-nloglik";
} }
@ -480,16 +498,16 @@ XGBOOST_REGISTER_OBJECTIVE(CoxRegression, "survival:cox")
class GammaRegression : public FitIntercept { class GammaRegression : public FitIntercept {
public: public:
void Configure(Args const&) override {} void Configure(Args const&) override {}
ObjInfo Task() const override { return ObjInfo::kRegression; } [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; }
void GetGradient(const HostDeviceVector<bst_float> &preds, void GetGradient(const HostDeviceVector<bst_float>& preds, const MetaInfo& info, std::int32_t,
const MetaInfo &info, int, linalg::Matrix<GradientPair>* out_gpair) override {
HostDeviceVector<GradientPair> *out_gpair) override {
CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty"; CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided"; CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided";
const size_t ndata = preds.Size(); const size_t ndata = preds.Size();
auto device = ctx_->gpu_id; auto device = ctx_->gpu_id;
out_gpair->Resize(ndata); out_gpair->SetDevice(ctx_->Device());
out_gpair->Reshape(info.num_row_, this->Targets(info));
label_correct_.Resize(1); label_correct_.Resize(1);
label_correct_.Fill(1); label_correct_.Fill(1);
@ -514,7 +532,7 @@ class GammaRegression : public FitIntercept {
_out_gpair[_idx] = GradientPair((1 - y / expf(p)) * w, y / expf(p) * w); _out_gpair[_idx] = GradientPair((1 - y / expf(p)) * w, y / expf(p) * w);
}, },
common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(), device).Eval( common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(), device).Eval(
&label_correct_, out_gpair, &preds, info.labels.Data(), &info.weights_); &label_correct_, out_gpair->Data(), &preds, info.labels.Data(), &info.weights_);
// copy "label correct" flags back to host // copy "label correct" flags back to host
std::vector<int>& label_correct_h = label_correct_.HostVector(); std::vector<int>& label_correct_h = label_correct_.HostVector();
@ -536,10 +554,10 @@ class GammaRegression : public FitIntercept {
void EvalTransform(HostDeviceVector<bst_float> *io_preds) override { void EvalTransform(HostDeviceVector<bst_float> *io_preds) override {
PredTransform(io_preds); PredTransform(io_preds);
} }
bst_float ProbToMargin(bst_float base_score) const override { [[nodiscard]] float ProbToMargin(bst_float base_score) const override {
return std::log(base_score); return std::log(base_score);
} }
const char* DefaultEvalMetric() const override { [[nodiscard]] const char* DefaultEvalMetric() const override {
return "gamma-nloglik"; return "gamma-nloglik";
} }
void SaveConfig(Json* p_out) const override { void SaveConfig(Json* p_out) const override {
@ -578,15 +596,15 @@ class TweedieRegression : public FitIntercept {
metric_ = os.str(); metric_ = os.str();
} }
ObjInfo Task() const override { return ObjInfo::kRegression; } [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; }
void GetGradient(const HostDeviceVector<bst_float>& preds, void GetGradient(const HostDeviceVector<bst_float>& preds, const MetaInfo& info, std::int32_t,
const MetaInfo &info, int, linalg::Matrix<GradientPair>* out_gpair) override {
HostDeviceVector<GradientPair> *out_gpair) override {
CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty"; CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided"; CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided";
const size_t ndata = preds.Size(); const size_t ndata = preds.Size();
out_gpair->Resize(ndata); out_gpair->SetDevice(ctx_->Device());
out_gpair->Reshape(info.num_row_, this->Targets(info));
auto device = ctx_->gpu_id; auto device = ctx_->gpu_id;
label_correct_.Resize(1); label_correct_.Resize(1);
@ -619,7 +637,7 @@ class TweedieRegression : public FitIntercept {
_out_gpair[_idx] = GradientPair(grad * w, hess * w); _out_gpair[_idx] = GradientPair(grad * w, hess * w);
}, },
common::Range{0, static_cast<int64_t>(ndata), 1}, this->ctx_->Threads(), device) common::Range{0, static_cast<int64_t>(ndata), 1}, this->ctx_->Threads(), device)
.Eval(&label_correct_, out_gpair, &preds, info.labels.Data(), &info.weights_); .Eval(&label_correct_, out_gpair->Data(), &preds, info.labels.Data(), &info.weights_);
// copy "label correct" flags back to host // copy "label correct" flags back to host
std::vector<int>& label_correct_h = label_correct_.HostVector(); std::vector<int>& label_correct_h = label_correct_.HostVector();
@ -639,11 +657,11 @@ class TweedieRegression : public FitIntercept {
.Eval(io_preds); .Eval(io_preds);
} }
bst_float ProbToMargin(bst_float base_score) const override { [[nodiscard]] float ProbToMargin(bst_float base_score) const override {
return std::log(base_score); return std::log(base_score);
} }
const char* DefaultEvalMetric() const override { [[nodiscard]] const char* DefaultEvalMetric() const override {
return metric_.c_str(); return metric_.c_str();
} }
@ -672,19 +690,19 @@ XGBOOST_REGISTER_OBJECTIVE(TweedieRegression, "reg:tweedie")
class MeanAbsoluteError : public ObjFunction { class MeanAbsoluteError : public ObjFunction {
public: public:
void Configure(Args const&) override {} void Configure(Args const&) override {}
ObjInfo Task() const override { return {ObjInfo::kRegression, true, true}; } [[nodiscard]] ObjInfo Task() const override { return {ObjInfo::kRegression, true, true}; }
bst_target_t Targets(MetaInfo const& info) const override { [[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override {
return std::max(static_cast<size_t>(1), info.labels.Shape(1)); return std::max(static_cast<std::size_t>(1), info.labels.Shape(1));
} }
void GetGradient(HostDeviceVector<bst_float> const& preds, const MetaInfo& info, int /*iter*/, void GetGradient(HostDeviceVector<float> const& preds, const MetaInfo& info,
HostDeviceVector<GradientPair>* out_gpair) override { std::int32_t /*iter*/, linalg::Matrix<GradientPair>* out_gpair) override {
CheckRegInputs(info, preds); CheckRegInputs(info, preds);
auto labels = info.labels.View(ctx_->gpu_id); auto labels = info.labels.View(ctx_->gpu_id);
out_gpair->SetDevice(ctx_->gpu_id); out_gpair->SetDevice(ctx_->Device());
out_gpair->Resize(info.labels.Size()); out_gpair->Reshape(info.num_row_, this->Targets(info));
auto gpair = linalg::MakeVec(out_gpair); auto gpair = out_gpair->View(ctx_->Device());
preds.SetDevice(ctx_->gpu_id); preds.SetDevice(ctx_->gpu_id);
auto predt = linalg::MakeVec(&preds); auto predt = linalg::MakeVec(&preds);
@ -692,14 +710,14 @@ class MeanAbsoluteError : public ObjFunction {
common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan() common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan()
: info.weights_.ConstDeviceSpan()}; : info.weights_.ConstDeviceSpan()};
linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(size_t i, float const y) mutable { linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(std::size_t i, float y) mutable {
auto sign = [](auto x) { auto sign = [](auto x) {
return (x > static_cast<decltype(x)>(0)) - (x < static_cast<decltype(x)>(0)); return (x > static_cast<decltype(x)>(0)) - (x < static_cast<decltype(x)>(0));
}; };
auto sample_id = std::get<0>(linalg::UnravelIndex(i, labels.Shape())); auto [sample_id, target_id] = linalg::UnravelIndex(i, labels.Shape());
auto grad = sign(predt(i) - y) * weight[sample_id]; auto grad = sign(predt(i) - y) * weight[sample_id];
auto hess = weight[sample_id]; auto hess = weight[sample_id];
gpair(i) = GradientPair{grad, hess}; gpair(sample_id, target_id) = GradientPair{grad, hess};
}); });
} }
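
The structured binding above relies on linalg::UnravelIndex mapping a flat element index back to (sample, target) coordinates of the row-major gradient matrix. A self-contained sketch of the underlying arithmetic, with illustrative names:

#include <cstddef>
#include <utility>

// For a row-major (n_samples x n_targets) layout, element i of the flat buffer is
//   sample_id = i / n_targets, target_id = i % n_targets.
std::pair<std::size_t, std::size_t> UnravelRowMajorSketch(std::size_t i, std::size_t n_targets) {
  return {i / n_targets, i % n_targets};
}
// Example: with n_targets = 2, flat index 7 corresponds to gpair(3, 1).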
@ -748,7 +766,7 @@ class MeanAbsoluteError : public ObjFunction {
p_tree); p_tree);
} }
const char* DefaultEvalMetric() const override { return "mae"; } [[nodiscard]] const char* DefaultEvalMetric() const override { return "mae"; }
void SaveConfig(Json* p_out) const override { void SaveConfig(Json* p_out) const override {
auto& out = *p_out; auto& out = *p_out;
@ -763,5 +781,4 @@ class MeanAbsoluteError : public ObjFunction {
XGBOOST_REGISTER_OBJECTIVE(MeanAbsoluteError, "reg:absoluteerror") XGBOOST_REGISTER_OBJECTIVE(MeanAbsoluteError, "reg:absoluteerror")
.describe("Mean absoluate error.") .describe("Mean absoluate error.")
.set_body([]() { return new MeanAbsoluteError(); }); .set_body([]() { return new MeanAbsoluteError(); });
} // namespace obj } // namespace xgboost::obj
} // namespace xgboost


@ -66,14 +66,13 @@ inline void FitStump(Context const*, linalg::TensorView<GradientPair const, 2>,
#endif // !defined(XGBOOST_USE_CUDA) #endif // !defined(XGBOOST_USE_CUDA)
} // namespace cuda_impl } // namespace cuda_impl
void FitStump(Context const* ctx, MetaInfo const& info, HostDeviceVector<GradientPair> const& gpair, void FitStump(Context const* ctx, MetaInfo const& info, linalg::Matrix<GradientPair> const& gpair,
bst_target_t n_targets, linalg::Vector<float>* out) { bst_target_t n_targets, linalg::Vector<float>* out) {
out->SetDevice(ctx->gpu_id); out->SetDevice(ctx->gpu_id);
out->Reshape(n_targets); out->Reshape(n_targets);
auto n_samples = gpair.Size() / n_targets;
gpair.SetDevice(ctx->gpu_id); gpair.SetDevice(ctx->Device());
auto gpair_t = linalg::MakeTensorView(ctx, &gpair, n_samples, n_targets); auto gpair_t = gpair.View(ctx->Device());
ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView()) ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView())
: cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id)); : cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id));
} }
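
A sketch of a call site under the new FitStump signature above: the two-dimensional gradient matrix is passed directly, and the intermediate MakeTensorView over a flat HostDeviceVector disappears. Function and variable names here are illustrative; only the FitStump declaration shown in this diff is assumed.

// Hypothetical caller.
void FitStumpCallSketch(xgboost::Context const* ctx, xgboost::MetaInfo const& info,
                        xgboost::linalg::Matrix<xgboost::GradientPair> const& gpair) {
  auto n_targets = static_cast<xgboost::bst_target_t>(gpair.Shape(1));
  xgboost::linalg::Vector<float> leaf;  // one base_score estimate per target
  xgboost::tree::FitStump(ctx, info, gpair, n_targets, &leaf);
}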


@ -31,7 +31,7 @@ XGBOOST_DEVICE inline double CalcUnregularizedWeight(T sum_grad, T sum_hess) {
/** /**
* @brief Fit a tree stump as an estimation of base_score. * @brief Fit a tree stump as an estimation of base_score.
*/ */
void FitStump(Context const* ctx, MetaInfo const& info, HostDeviceVector<GradientPair> const& gpair, void FitStump(Context const* ctx, MetaInfo const& info, linalg::Matrix<GradientPair> const& gpair,
bst_target_t n_targets, linalg::Vector<float>* out); bst_target_t n_targets, linalg::Vector<float>* out);
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost


@ -269,17 +269,18 @@ class GlobalApproxUpdater : public TreeUpdater {
out["hist_train_param"] = ToJson(hist_param_); out["hist_train_param"] = ToJson(hist_param_);
} }
void InitData(TrainParam const &param, HostDeviceVector<GradientPair> const *gpair, void InitData(TrainParam const &param, linalg::Matrix<GradientPair> const *gpair,
linalg::Matrix<GradientPair> *sampled) { linalg::Matrix<GradientPair> *sampled) {
*sampled = linalg::Empty<GradientPair>(ctx_, gpair->Size(), 1); *sampled = linalg::Empty<GradientPair>(ctx_, gpair->Size(), 1);
sampled->Data()->Copy(*gpair); auto in = gpair->HostView().Values();
std::copy(in.data(), in.data() + in.size(), sampled->HostView().Values().data());
SampleGradient(ctx_, param, sampled->HostView()); SampleGradient(ctx_, param, sampled->HostView());
} }
[[nodiscard]] char const *Name() const override { return "grow_histmaker"; } [[nodiscard]] char const *Name() const override { return "grow_histmaker"; }
void Update(TrainParam const *param, HostDeviceVector<GradientPair> *gpair, DMatrix *m, void Update(TrainParam const *param, linalg::Matrix<GradientPair> *gpair, DMatrix *m,
common::Span<HostDeviceVector<bst_node_t>> out_position, common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree *> &trees) override { const std::vector<RegTree *> &trees) override {
CHECK(hist_param_.GetInitialised()); CHECK(hist_param_.GetInitialised());


@ -91,7 +91,7 @@ class ColMaker: public TreeUpdater {
} }
} }
void Update(TrainParam const *param, HostDeviceVector<GradientPair> *gpair, DMatrix *dmat, void Update(TrainParam const *param, linalg::Matrix<GradientPair> *gpair, DMatrix *dmat,
common::Span<HostDeviceVector<bst_node_t>> /*out_position*/, common::Span<HostDeviceVector<bst_node_t>> /*out_position*/,
const std::vector<RegTree *> &trees) override { const std::vector<RegTree *> &trees) override {
if (collective::IsDistributed()) { if (collective::IsDistributed()) {
@ -106,10 +106,11 @@ class ColMaker: public TreeUpdater {
// rescale learning rate according to size of trees // rescale learning rate according to size of trees
interaction_constraints_.Configure(*param, dmat->Info().num_row_); interaction_constraints_.Configure(*param, dmat->Info().num_row_);
// build tree // build tree
CHECK_EQ(gpair->Shape(1), 1) << MTNotImplemented();
for (auto tree : trees) { for (auto tree : trees) {
CHECK(ctx_); CHECK(ctx_);
Builder builder(*param, colmaker_param_, interaction_constraints_, ctx_, column_densities_); Builder builder(*param, colmaker_param_, interaction_constraints_, ctx_, column_densities_);
builder.Update(gpair->ConstHostVector(), dmat, tree); builder.Update(gpair->Data()->ConstHostVector(), dmat, tree);
} }
} }
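
The same pattern recurs in the single-target updaters below: assert that the gradient matrix has exactly one column, then fall back to the flat vector stored inside it. A minimal sketch with an illustrative name; the CHECK macro from xgboost/logging.h is assumed.

#include "xgboost/logging.h"  // CHECK_EQ

// Hypothetical helper showing how a single-target updater consumes the new gradient type.
void ConsumeFlatGradientsSketch(xgboost::linalg::Matrix<xgboost::GradientPair>* gpair) {
  CHECK_EQ(gpair->Shape(1), 1);  // multi-target trees are not handled here
  auto const& flat = gpair->Data()->ConstHostVector();  // row-major, one pair per sample
  for (auto const& g : flat) {
    (void)g.GetGrad();  // use gradient/hessian exactly as before the type change
    (void)g.GetHess();
  }
}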


@ -760,16 +760,18 @@ class GPUHistMaker : public TreeUpdater {
dh::GlobalMemoryLogger().Log(); dh::GlobalMemoryLogger().Log();
} }
void Update(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* dmat, void Update(TrainParam const* param, linalg::Matrix<GradientPair>* gpair, DMatrix* dmat,
common::Span<HostDeviceVector<bst_node_t>> out_position, common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree*>& trees) override { const std::vector<RegTree*>& trees) override {
monitor_.Start("Update"); monitor_.Start("Update");
CHECK_EQ(gpair->Shape(1), 1) << MTNotImplemented();
auto gpair_hdv = gpair->Data();
// build tree // build tree
try { try {
std::size_t t_idx{0}; std::size_t t_idx{0};
for (xgboost::RegTree* tree : trees) { for (xgboost::RegTree* tree : trees) {
this->UpdateTree(param, gpair, dmat, tree, &out_position[t_idx]); this->UpdateTree(param, gpair_hdv, dmat, tree, &out_position[t_idx]);
this->hist_maker_param_.CheckTreesSynchronized(tree); this->hist_maker_param_.CheckTreesSynchronized(tree);
++t_idx; ++t_idx;
} }
@ -887,7 +889,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {
} }
~GPUGlobalApproxMaker() override { dh::GlobalMemoryLogger().Log(); } ~GPUGlobalApproxMaker() override { dh::GlobalMemoryLogger().Log(); }
void Update(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, void Update(TrainParam const* param, linalg::Matrix<GradientPair>* gpair, DMatrix* p_fmat,
common::Span<HostDeviceVector<bst_node_t>> out_position, common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree*>& trees) override { const std::vector<RegTree*>& trees) override {
monitor_.Start("Update"); monitor_.Start("Update");
@ -898,7 +900,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {
auto hess = dh::ToSpan(hess_); auto hess = dh::ToSpan(hess_);
gpair->SetDevice(ctx_->Device()); gpair->SetDevice(ctx_->Device());
auto d_gpair = gpair->ConstDeviceSpan(); auto d_gpair = gpair->Data()->ConstDeviceSpan();
auto cuctx = ctx_->CUDACtx(); auto cuctx = ctx_->CUDACtx();
thrust::transform(cuctx->CTP(), dh::tcbegin(d_gpair), dh::tcend(d_gpair), dh::tbegin(hess), thrust::transform(cuctx->CTP(), dh::tcbegin(d_gpair), dh::tcend(d_gpair), dh::tbegin(hess),
[=] XGBOOST_DEVICE(GradientPair const& g) { return g.GetHess(); }); [=] XGBOOST_DEVICE(GradientPair const& g) { return g.GetHess(); });
@ -912,7 +914,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {
std::size_t t_idx{0}; std::size_t t_idx{0};
for (xgboost::RegTree* tree : trees) { for (xgboost::RegTree* tree : trees) {
this->UpdateTree(gpair, p_fmat, tree, &out_position[t_idx]); this->UpdateTree(gpair->Data(), p_fmat, tree, &out_position[t_idx]);
this->hist_maker_param_.CheckTreesSynchronized(tree); this->hist_maker_param_.CheckTreesSynchronized(tree);
++t_idx; ++t_idx;
} }


@ -31,7 +31,7 @@ class TreePruner : public TreeUpdater {
[[nodiscard]] bool CanModifyTree() const override { return true; } [[nodiscard]] bool CanModifyTree() const override { return true; }
// update the tree, do pruning // update the tree, do pruning
void Update(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, void Update(TrainParam const* param, linalg::Matrix<GradientPair>* gpair, DMatrix* p_fmat,
common::Span<HostDeviceVector<bst_node_t>> out_position, common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree*>& trees) override { const std::vector<RegTree*>& trees) override {
pruner_monitor_.Start("PrunerUpdate"); pruner_monitor_.Start("PrunerUpdate");


@ -492,7 +492,7 @@ class QuantileHistMaker : public TreeUpdater {
[[nodiscard]] char const *Name() const override { return "grow_quantile_histmaker"; } [[nodiscard]] char const *Name() const override { return "grow_quantile_histmaker"; }
void Update(TrainParam const *param, HostDeviceVector<GradientPair> *gpair, DMatrix *p_fmat, void Update(TrainParam const *param, linalg::Matrix<GradientPair> *gpair, DMatrix *p_fmat,
common::Span<HostDeviceVector<bst_node_t>> out_position, common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree *> &trees) override { const std::vector<RegTree *> &trees) override {
if (trees.front()->IsMultiTarget()) { if (trees.front()->IsMultiTarget()) {
@ -511,8 +511,7 @@ class QuantileHistMaker : public TreeUpdater {
} }
bst_target_t n_targets = trees.front()->NumTargets(); bst_target_t n_targets = trees.front()->NumTargets();
auto h_gpair = auto h_gpair = gpair->HostView();
linalg::MakeTensorView(ctx_, gpair->HostSpan(), p_fmat->Info().num_row_, n_targets);
linalg::Matrix<GradientPair> sample_out; linalg::Matrix<GradientPair> sample_out;
auto h_sample_out = h_gpair; auto h_sample_out = h_gpair;


@ -31,11 +31,14 @@ class TreeRefresher : public TreeUpdater {
[[nodiscard]] char const *Name() const override { return "refresh"; } [[nodiscard]] char const *Name() const override { return "refresh"; }
[[nodiscard]] bool CanModifyTree() const override { return true; } [[nodiscard]] bool CanModifyTree() const override { return true; }
// update the tree, do pruning // update the tree, do pruning
void Update(TrainParam const *param, HostDeviceVector<GradientPair> *gpair, DMatrix *p_fmat, void Update(TrainParam const *param, linalg::Matrix<GradientPair> *gpair, DMatrix *p_fmat,
common::Span<HostDeviceVector<bst_node_t>> /*out_position*/, common::Span<HostDeviceVector<bst_node_t>> /*out_position*/,
const std::vector<RegTree *> &trees) override { const std::vector<RegTree *> &trees) override {
if (trees.size() == 0) return; if (trees.size() == 0) {
const std::vector<GradientPair> &gpair_h = gpair->ConstHostVector(); return;
}
CHECK_EQ(gpair->Shape(1), 1) << MTNotImplemented();
const std::vector<GradientPair> &gpair_h = gpair->Data()->ConstHostVector();
// thread temporal space // thread temporal space
std::vector<std::vector<GradStats> > stemp; std::vector<std::vector<GradStats> > stemp;
std::vector<RegTree::FVec> fvec_temp; std::vector<RegTree::FVec> fvec_temp;


@ -31,7 +31,7 @@ class TreeSyncher : public TreeUpdater {
[[nodiscard]] char const* Name() const override { return "prune"; } [[nodiscard]] char const* Name() const override { return "prune"; }
void Update(TrainParam const*, HostDeviceVector<GradientPair>*, DMatrix*, void Update(TrainParam const*, linalg::Matrix<GradientPair>*, DMatrix*,
common::Span<HostDeviceVector<bst_node_t>> /*out_position*/, common::Span<HostDeviceVector<bst_node_t>> /*out_position*/,
const std::vector<RegTree*>& trees) override { const std::vector<RegTree*>& trees) override {
if (collective::GetWorldSize() == 1) return; if (collective::GetWorldSize() == 1) return;


@ -565,7 +565,7 @@ void TestXGDMatrixGetQuantileCut(Context const *ctx) {
ASSERT_EQ(XGBoosterCreate(mats.data(), 1, &booster), 0); ASSERT_EQ(XGBoosterCreate(mats.data(), 1, &booster), 0);
ASSERT_EQ(XGBoosterSetParam(booster, "max_bin", "16"), 0); ASSERT_EQ(XGBoosterSetParam(booster, "max_bin", "16"), 0);
if (ctx->IsCUDA()) { if (ctx->IsCUDA()) {
ASSERT_EQ(XGBoosterSetParam(booster, "tree_method", "gpu_hist"), 0); ASSERT_EQ(XGBoosterSetParam(booster, "device", ctx->DeviceName().c_str()), 0);
} }
ASSERT_EQ(XGBoosterUpdateOneIter(booster, 0, p_fmat), 0); ASSERT_EQ(XGBoosterUpdateOneIter(booster, 0, p_fmat), 0);
ASSERT_EQ(XGDMatrixGetQuantileCut(p_fmat, s_config.c_str(), &out_indptr, &out_data), 0); ASSERT_EQ(XGDMatrixGetQuantileCut(p_fmat, s_config.c_str(), &out_indptr, &out_data), 0);
@ -596,7 +596,7 @@ void TestXGDMatrixGetQuantileCut(Context const *ctx) {
ASSERT_EQ(XGBoosterCreate(mats.data(), 1, &booster), 0); ASSERT_EQ(XGBoosterCreate(mats.data(), 1, &booster), 0);
ASSERT_EQ(XGBoosterSetParam(booster, "max_bin", "16"), 0); ASSERT_EQ(XGBoosterSetParam(booster, "max_bin", "16"), 0);
if (ctx->IsCUDA()) { if (ctx->IsCUDA()) {
ASSERT_EQ(XGBoosterSetParam(booster, "tree_method", "gpu_hist"), 0); ASSERT_EQ(XGBoosterSetParam(booster, "device", ctx->DeviceName().c_str()), 0);
} }
ASSERT_EQ(XGBoosterUpdateOneIter(booster, 0, p_fmat), 0); ASSERT_EQ(XGBoosterUpdateOneIter(booster, 0, p_fmat), 0);
ASSERT_EQ(XGDMatrixGetQuantileCut(p_fmat, s_config.c_str(), &out_indptr, &out_data), 0); ASSERT_EQ(XGDMatrixGetQuantileCut(p_fmat, s_config.c_str(), &out_indptr, &out_data), 0);


@ -65,7 +65,9 @@ TEST(GBTree, PredictionCache) {
gbtree.Configure({{"tree_method", "hist"}}); gbtree.Configure({{"tree_method", "hist"}});
auto p_m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(); auto p_m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
auto gpair = GenerateRandomGradients(kRows); linalg::Matrix<GradientPair> gpair({kRows}, ctx.Ordinal());
gpair.Data()->Copy(GenerateRandomGradients(kRows));
PredictionCacheEntry out_predictions; PredictionCacheEntry out_predictions;
gbtree.DoBoost(p_m.get(), &gpair, &out_predictions, nullptr); gbtree.DoBoost(p_m.get(), &gpair, &out_predictions, nullptr);
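
A sketch of the wrap-and-copy pattern used in these tests: an existing flat HostDeviceVector of gradients is copied into a one-column linalg::Matrix. The helper name is hypothetical; only the constructor and Data()->Copy() calls visible in this hunk are assumed.

xgboost::linalg::Matrix<xgboost::GradientPair> WrapFlatGradientsSketch(
    xgboost::Context const* ctx, xgboost::HostDeviceVector<xgboost::GradientPair> const& flat) {
  // Shape is (n_rows,) with the remaining dimension filled as 1, i.e. a single target column.
  xgboost::linalg::Matrix<xgboost::GradientPair> gpair({flat.Size()}, ctx->Ordinal());
  gpair.Data()->Copy(flat);  // element order is preserved
  return gpair;
}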
@ -213,7 +215,8 @@ TEST(GBTree, ChooseTreeMethod) {
} }
learner->Configure(); learner->Configure();
for (std::int32_t i = 0; i < 3; ++i) { for (std::int32_t i = 0; i < 3; ++i) {
HostDeviceVector<GradientPair> gpair{GenerateRandomGradients(Xy->Info().num_row_)}; linalg::Matrix<GradientPair> gpair{{Xy->Info().num_row_}, Context::kCpuId};
gpair.Data()->Copy(GenerateRandomGradients(Xy->Info().num_row_));
learner->BoostOneIter(0, Xy, &gpair); learner->BoostOneIter(0, Xy, &gpair);
} }


@ -96,9 +96,9 @@ void CheckObjFunctionImpl(std::unique_ptr<xgboost::ObjFunction> const& obj,
std::vector<xgboost::bst_float> out_grad, std::vector<xgboost::bst_float> out_grad,
std::vector<xgboost::bst_float> out_hess) { std::vector<xgboost::bst_float> out_hess) {
xgboost::HostDeviceVector<xgboost::bst_float> in_preds(preds); xgboost::HostDeviceVector<xgboost::bst_float> in_preds(preds);
xgboost::HostDeviceVector<xgboost::GradientPair> out_gpair; xgboost::linalg::Matrix<xgboost::GradientPair> out_gpair;
obj->GetGradient(in_preds, info, 1, &out_gpair); obj->GetGradient(in_preds, info, 0, &out_gpair);
std::vector<xgboost::GradientPair>& gpair = out_gpair.HostVector(); std::vector<xgboost::GradientPair>& gpair = out_gpair.Data()->HostVector();
ASSERT_EQ(gpair.size(), in_preds.Size()); ASSERT_EQ(gpair.size(), in_preds.Size());
for (int i = 0; i < static_cast<int>(gpair.size()); ++i) { for (int i = 0; i < static_cast<int>(gpair.size()); ++i) {
@ -119,8 +119,8 @@ void CheckObjFunction(std::unique_ptr<xgboost::ObjFunction> const& obj,
std::vector<xgboost::bst_float> out_hess) { std::vector<xgboost::bst_float> out_hess) {
xgboost::MetaInfo info; xgboost::MetaInfo info;
info.num_row_ = labels.size(); info.num_row_ = labels.size();
info.labels = info.labels = xgboost::linalg::Tensor<float, 2>{
xgboost::linalg::Tensor<float, 2>{labels.cbegin(), labels.cend(), {labels.size()}, -1}; labels.cbegin(), labels.cend(), {labels.size(), static_cast<std::size_t>(1)}, -1};
info.weights_.HostVector() = weights; info.weights_.HostVector() = weights;
CheckObjFunctionImpl(obj, preds, labels, weights, info, out_grad, out_hess); CheckObjFunctionImpl(obj, preds, labels, weights, info, out_grad, out_hess);
@ -155,8 +155,8 @@ void CheckRankingObjFunction(std::unique_ptr<xgboost::ObjFunction> const& obj,
std::vector<xgboost::bst_float> out_hess) { std::vector<xgboost::bst_float> out_hess) {
xgboost::MetaInfo info; xgboost::MetaInfo info;
info.num_row_ = labels.size(); info.num_row_ = labels.size();
info.labels = xgboost::linalg::Tensor<float, 2>{ info.labels = xgboost::linalg::Matrix<float>{
labels.cbegin(), labels.cend(), {labels.size(), static_cast<size_t>(1)}, -1}; labels.cbegin(), labels.cend(), {labels.size(), static_cast<std::size_t>(1)}, -1};
info.weights_.HostVector() = weights; info.weights_.HostVector() = weights;
info.group_ptr_ = groups; info.group_ptr_ = groups;
@ -645,11 +645,10 @@ std::unique_ptr<GradientBooster> CreateTrainedGBM(std::string name, Args kwargs,
} }
p_dmat->Info().labels = p_dmat->Info().labels =
linalg::Tensor<float, 2>{labels.cbegin(), labels.cend(), {labels.size()}, -1}; linalg::Tensor<float, 2>{labels.cbegin(), labels.cend(), {labels.size()}, -1};
HostDeviceVector<GradientPair> gpair; linalg::Matrix<GradientPair> gpair({kRows}, ctx->Ordinal());
auto& h_gpair = gpair.HostVector(); auto h_gpair = gpair.HostView();
h_gpair.resize(kRows);
for (size_t i = 0; i < kRows; ++i) { for (size_t i = 0; i < kRows; ++i) {
h_gpair[i] = GradientPair{static_cast<float>(i), 1}; h_gpair(i) = GradientPair{static_cast<float>(i), 1};
} }
PredictionCacheEntry predts; PredictionCacheEntry predts;


@ -387,23 +387,6 @@ std::unique_ptr<GradientBooster> CreateTrainedGBM(std::string name, Args kwargs,
LearnerModelParam const* learner_model_param, LearnerModelParam const* learner_model_param,
Context const* generic_param); Context const* generic_param);
inline std::unique_ptr<HostDeviceVector<GradientPair>> GenerateGradients(
std::size_t rows, bst_target_t n_targets = 1) {
auto p_gradients = std::make_unique<HostDeviceVector<GradientPair>>(rows * n_targets);
auto& h_gradients = p_gradients->HostVector();
xgboost::SimpleLCG gen;
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
for (std::size_t i = 0; i < rows * n_targets; ++i) {
auto grad = dist(&gen);
auto hess = dist(&gen);
h_gradients[i] = GradientPair{grad, hess};
}
return p_gradients;
}
/** /**
* \brief Make a context that uses CUDA if device >= 0. * \brief Make a context that uses CUDA if device >= 0.
*/ */
@ -415,7 +398,8 @@ inline Context MakeCUDACtx(std::int32_t device) {
} }
inline HostDeviceVector<GradientPair> GenerateRandomGradients(const size_t n_rows, inline HostDeviceVector<GradientPair> GenerateRandomGradients(const size_t n_rows,
float lower= 0.0f, float upper = 1.0f) { float lower = 0.0f,
float upper = 1.0f) {
xgboost::SimpleLCG gen; xgboost::SimpleLCG gen;
xgboost::SimpleRealUniformDistribution<bst_float> dist(lower, upper); xgboost::SimpleRealUniformDistribution<bst_float> dist(lower, upper);
std::vector<GradientPair> h_gpair(n_rows); std::vector<GradientPair> h_gpair(n_rows);
@ -428,6 +412,16 @@ inline HostDeviceVector<GradientPair> GenerateRandomGradients(const size_t n_row
return gpair; return gpair;
} }
inline linalg::Matrix<GradientPair> GenerateRandomGradients(Context const* ctx, bst_row_t n_rows,
bst_target_t n_targets,
float lower = 0.0f,
float upper = 1.0f) {
auto g = GenerateRandomGradients(n_rows * n_targets, lower, upper);
linalg::Matrix<GradientPair> gpair({n_rows, static_cast<bst_row_t>(n_targets)}, ctx->Device());
gpair.Data()->Copy(g);
return gpair;
}
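
A possible use of the new overload in a test translation unit that includes this header, assuming only the helpers declared here and googletest; the test name is illustrative.

TEST(HelpersSketch, GenerateRandomGradientsMatrix) {
  Context ctx;  // default CPU context
  auto gpair = GenerateRandomGradients(&ctx, 128, 3);  // 128 rows, 3 targets
  ASSERT_EQ(gpair.Shape(0), 128ul);
  ASSERT_EQ(gpair.Shape(1), 3ul);
  ASSERT_EQ(gpair.Size(), 384ul);
}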
typedef void *DMatrixHandle; // NOLINT(*); typedef void *DMatrixHandle; // NOLINT(*);
class ArrayIterForTest { class ArrayIterForTest {


@ -24,8 +24,8 @@ TEST(Linear, Shotgun) {
auto updater = auto updater =
std::unique_ptr<xgboost::LinearUpdater>(xgboost::LinearUpdater::Create("shotgun", &ctx)); std::unique_ptr<xgboost::LinearUpdater>(xgboost::LinearUpdater::Create("shotgun", &ctx));
updater->Configure({{"eta", "1."}}); updater->Configure({{"eta", "1."}});
xgboost::HostDeviceVector<xgboost::GradientPair> gpair( linalg::Matrix<xgboost::GradientPair> gpair{
p_fmat->Info().num_row_, xgboost::GradientPair(-5, 1.0)); linalg::Constant(&ctx, xgboost::GradientPair(-5, 1.0), p_fmat->Info().num_row_, 1)};
xgboost::gbm::GBLinearModel model{&mparam}; xgboost::gbm::GBLinearModel model{&mparam};
model.LazyInitModel(); model.LazyInitModel();
updater->Update(&gpair, p_fmat.get(), &model, gpair.Size()); updater->Update(&gpair, p_fmat.get(), &model, gpair.Size());
@ -55,8 +55,8 @@ TEST(Linear, coordinate) {
auto updater = std::unique_ptr<xgboost::LinearUpdater>( auto updater = std::unique_ptr<xgboost::LinearUpdater>(
xgboost::LinearUpdater::Create("coord_descent", &ctx)); xgboost::LinearUpdater::Create("coord_descent", &ctx));
updater->Configure({{"eta", "1."}}); updater->Configure({{"eta", "1."}});
xgboost::HostDeviceVector<xgboost::GradientPair> gpair( linalg::Matrix<xgboost::GradientPair> gpair{
p_fmat->Info().num_row_, xgboost::GradientPair(-5, 1.0)); linalg::Constant(&ctx, xgboost::GradientPair(-5, 1.0), p_fmat->Info().num_row_, 1)};
xgboost::gbm::GBLinearModel model{&mparam}; xgboost::gbm::GBLinearModel model{&mparam};
model.LazyInitModel(); model.LazyInitModel();
updater->Update(&gpair, p_fmat.get(), &model, gpair.Size()); updater->Update(&gpair, p_fmat.get(), &model, gpair.Size());


@ -1,4 +1,6 @@
// Copyright by Contributors /**
* Copyright 2018-2023, XGBoost Contributors
*/
#include <xgboost/linear_updater.h> #include <xgboost/linear_updater.h>
#include <xgboost/gbm.h> #include <xgboost/gbm.h>
@ -19,8 +21,7 @@ TEST(Linear, GPUCoordinate) {
auto updater = std::unique_ptr<xgboost::LinearUpdater>( auto updater = std::unique_ptr<xgboost::LinearUpdater>(
xgboost::LinearUpdater::Create("gpu_coord_descent", &ctx)); xgboost::LinearUpdater::Create("gpu_coord_descent", &ctx));
updater->Configure({{"eta", "1."}}); updater->Configure({{"eta", "1."}});
xgboost::HostDeviceVector<xgboost::GradientPair> gpair( auto gpair = linalg::Constant(&ctx, xgboost::GradientPair(-5, 1.0), mat->Info().num_row_, 1);
mat->Info().num_row_, xgboost::GradientPair(-5, 1.0));
xgboost::gbm::GBLinearModel model{&mparam}; xgboost::gbm::GBLinearModel model{&mparam};
model.LazyInitModel(); model.LazyInitModel();


@ -1,5 +1,5 @@
/*! /**
* Copyright (c) by Contributors 2020 * Copyright 2020-2023, XGBoost Contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
@ -12,9 +12,7 @@
#include "../helpers.h" #include "../helpers.h"
#include "../../../src/common/survival_util.h" #include "../../../src/common/survival_util.h"
namespace xgboost { namespace xgboost::common {
namespace common {
TEST(Objective, DeclareUnifiedTest(AFTObjConfiguration)) { TEST(Objective, DeclareUnifiedTest(AFTObjConfiguration)) {
auto ctx = MakeCUDACtx(GPUIDX); auto ctx = MakeCUDACtx(GPUIDX);
std::unique_ptr<ObjFunction> objective(ObjFunction::Create("survival:aft", &ctx)); std::unique_ptr<ObjFunction> objective(ObjFunction::Create("survival:aft", &ctx));
@ -65,14 +63,14 @@ static inline void CheckGPairOverGridPoints(
preds[i] = std::log(std::pow(2.0, i * (log_y_high - log_y_low) / (num_point - 1) + log_y_low)); preds[i] = std::log(std::pow(2.0, i * (log_y_high - log_y_low) / (num_point - 1) + log_y_low));
} }
HostDeviceVector<GradientPair> out_gpair; linalg::Matrix<GradientPair> out_gpair;
obj->GetGradient(HostDeviceVector<bst_float>(preds), info, 1, &out_gpair); obj->GetGradient(HostDeviceVector<bst_float>(preds), info, 1, &out_gpair);
const auto& gpair = out_gpair.HostVector(); const auto gpair = out_gpair.HostView();
CHECK_EQ(num_point, expected_grad.size()); CHECK_EQ(num_point, expected_grad.size());
CHECK_EQ(num_point, expected_hess.size()); CHECK_EQ(num_point, expected_hess.size());
for (int i = 0; i < num_point; ++i) { for (int i = 0; i < num_point; ++i) {
EXPECT_NEAR(gpair[i].GetGrad(), expected_grad[i], ftol); EXPECT_NEAR(gpair(i).GetGrad(), expected_grad[i], ftol);
EXPECT_NEAR(gpair[i].GetHess(), expected_hess[i], ftol); EXPECT_NEAR(gpair(i).GetHess(), expected_hess[i], ftol);
} }
} }
@ -169,5 +167,4 @@ TEST(Objective, DeclareUnifiedTest(AFTObjGPairIntervalCensoredLabels)) {
0.2757f, 0.1776f, 0.1110f, 0.0682f, 0.0415f, 0.0251f, 0.0151f, 0.0091f, 0.0055f, 0.0033f }); 0.2757f, 0.1776f, 0.1110f, 0.0682f, 0.0415f, 0.0251f, 0.0151f, 0.0091f, 0.0055f, 0.0033f });
} }
} // namespace common } // namespace xgboost::common
} // namespace xgboost


@ -74,35 +74,35 @@ void TestNDCGGPair(Context const* ctx) {
info.labels = linalg::Tensor<float, 2>{{0, 1, 0, 1}, {4, 1}, GPUIDX}; info.labels = linalg::Tensor<float, 2>{{0, 1, 0, 1}, {4, 1}, GPUIDX};
info.group_ptr_ = {0, 2, 4}; info.group_ptr_ = {0, 2, 4};
info.num_row_ = 4; info.num_row_ = 4;
HostDeviceVector<GradientPair> gpairs; linalg::Matrix<GradientPair> gpairs;
obj->GetGradient(predts, info, 0, &gpairs); obj->GetGradient(predts, info, 0, &gpairs);
ASSERT_EQ(gpairs.Size(), predts.Size()); ASSERT_EQ(gpairs.Size(), predts.Size());
{ {
predts = {1, 0, 1, 0}; predts = {1, 0, 1, 0};
HostDeviceVector<GradientPair> gpairs; linalg::Matrix<GradientPair> gpairs;
obj->GetGradient(predts, info, 0, &gpairs); obj->GetGradient(predts, info, 0, &gpairs);
for (size_t i = 0; i < gpairs.Size(); ++i) { for (std::size_t i = 0; i < gpairs.Size(); ++i) {
ASSERT_GT(gpairs.HostSpan()[i].GetHess(), 0); ASSERT_GT(gpairs.HostView()(i).GetHess(), 0);
} }
ASSERT_LT(gpairs.HostSpan()[1].GetGrad(), 0); ASSERT_LT(gpairs.HostView()(1).GetGrad(), 0);
ASSERT_LT(gpairs.HostSpan()[3].GetGrad(), 0); ASSERT_LT(gpairs.HostView()(3).GetGrad(), 0);
ASSERT_GT(gpairs.HostSpan()[0].GetGrad(), 0); ASSERT_GT(gpairs.HostView()(0).GetGrad(), 0);
ASSERT_GT(gpairs.HostSpan()[2].GetGrad(), 0); ASSERT_GT(gpairs.HostView()(2).GetGrad(), 0);
info.weights_ = {2, 3}; info.weights_ = {2, 3};
HostDeviceVector<GradientPair> weighted_gpairs; linalg::Matrix<GradientPair> weighted_gpairs;
obj->GetGradient(predts, info, 0, &weighted_gpairs); obj->GetGradient(predts, info, 0, &weighted_gpairs);
auto const& h_gpairs = gpairs.ConstHostSpan(); auto const& h_gpairs = gpairs.HostView();
auto const& h_weighted_gpairs = weighted_gpairs.ConstHostSpan(); auto const& h_weighted_gpairs = weighted_gpairs.HostView();
for (size_t i : {0ul, 1ul}) { for (size_t i : {0ul, 1ul}) {
ASSERT_FLOAT_EQ(h_weighted_gpairs[i].GetGrad(), h_gpairs[i].GetGrad() * 2.0f); ASSERT_FLOAT_EQ(h_weighted_gpairs(i).GetGrad(), h_gpairs(i).GetGrad() * 2.0f);
ASSERT_FLOAT_EQ(h_weighted_gpairs[i].GetHess(), h_gpairs[i].GetHess() * 2.0f); ASSERT_FLOAT_EQ(h_weighted_gpairs(i).GetHess(), h_gpairs(i).GetHess() * 2.0f);
} }
for (size_t i : {2ul, 3ul}) { for (size_t i : {2ul, 3ul}) {
ASSERT_FLOAT_EQ(h_weighted_gpairs[i].GetGrad(), h_gpairs[i].GetGrad() * 3.0f); ASSERT_FLOAT_EQ(h_weighted_gpairs(i).GetGrad(), h_gpairs(i).GetGrad() * 3.0f);
ASSERT_FLOAT_EQ(h_weighted_gpairs[i].GetHess(), h_gpairs[i].GetHess() * 3.0f); ASSERT_FLOAT_EQ(h_weighted_gpairs(i).GetHess(), h_gpairs(i).GetHess() * 3.0f);
} }
} }
@ -125,7 +125,7 @@ void TestUnbiasedNDCG(Context const* ctx) {
std::sort(h_label.begin(), h_label.end(), std::greater<>{}); std::sort(h_label.begin(), h_label.end(), std::greater<>{});
HostDeviceVector<float> predt(p_fmat->Info().num_row_, 1.0f); HostDeviceVector<float> predt(p_fmat->Info().num_row_, 1.0f);
HostDeviceVector<GradientPair> out_gpair; linalg::Matrix<GradientPair> out_gpair;
obj->GetGradient(predt, p_fmat->Info(), 0, &out_gpair); obj->GetGradient(predt, p_fmat->Info(), 0, &out_gpair);
Json config{Object{}}; Json config{Object{}};


@ -42,7 +42,8 @@ void TestGPUMakePair() {
auto d = dummy.View(ctx.gpu_id); auto d = dummy.View(ctx.gpu_id);
linalg::Vector<GradientPair> dgpair; linalg::Vector<GradientPair> dgpair;
auto dg = dgpair.View(ctx.gpu_id); auto dg = dgpair.View(ctx.gpu_id);
cuda_impl::KernelInputs args{d, cuda_impl::KernelInputs args{
d,
d, d,
d, d,
d, d,
@ -51,7 +52,7 @@ void TestGPUMakePair() {
rank_idx, rank_idx,
info.labels.View(ctx.gpu_id), info.labels.View(ctx.gpu_id),
predt.ConstDeviceSpan(), predt.ConstDeviceSpan(),
{}, linalg::MatrixView<GradientPair>{common::Span<GradientPair>{}, {0}, 0},
dg, dg,
nullptr, nullptr,
y_sorted_idx, y_sorted_idx,


@ -122,7 +122,7 @@ TEST(Objective, DeclareUnifiedTest(LogisticRegressionBasic)) {
EXPECT_NEAR(obj->ProbToMargin(0.1f), -2.197f, 0.01f); EXPECT_NEAR(obj->ProbToMargin(0.1f), -2.197f, 0.01f);
EXPECT_NEAR(obj->ProbToMargin(0.5f), 0, 0.01f); EXPECT_NEAR(obj->ProbToMargin(0.5f), 0, 0.01f);
EXPECT_NEAR(obj->ProbToMargin(0.9f), 2.197f, 0.01f); EXPECT_NEAR(obj->ProbToMargin(0.9f), 2.197f, 0.01f);
EXPECT_ANY_THROW(obj->ProbToMargin(10)) EXPECT_ANY_THROW((void)obj->ProbToMargin(10))
<< "Expected error when base_score not in range [0,1f] for LogisticRegression"; << "Expected error when base_score not in range [0,1f] for LogisticRegression";
// test PredTransform // test PredTransform
@ -282,9 +282,9 @@ TEST(Objective, DeclareUnifiedTest(TweedieRegressionGPair)) {
TEST(Objective, CPU_vs_CUDA) { TEST(Objective, CPU_vs_CUDA) {
Context ctx = MakeCUDACtx(GPUIDX); Context ctx = MakeCUDACtx(GPUIDX);
ObjFunction* obj = ObjFunction::Create("reg:squarederror", &ctx); std::unique_ptr<ObjFunction> obj{ObjFunction::Create("reg:squarederror", &ctx)};
HostDeviceVector<GradientPair> cpu_out_preds; linalg::Matrix<GradientPair> cpu_out_preds;
HostDeviceVector<GradientPair> cuda_out_preds; linalg::Matrix<GradientPair> cuda_out_preds;
constexpr size_t kRows = 400; constexpr size_t kRows = 400;
constexpr size_t kCols = 100; constexpr size_t kCols = 100;
@ -300,7 +300,7 @@ TEST(Objective, CPU_vs_CUDA) {
info.labels.Reshape(kRows); info.labels.Reshape(kRows);
auto& h_labels = info.labels.Data()->HostVector(); auto& h_labels = info.labels.Data()->HostVector();
for (size_t i = 0; i < h_labels.size(); ++i) { for (size_t i = 0; i < h_labels.size(); ++i) {
h_labels[i] = 1 / (float)(i+1); h_labels[i] = 1 / static_cast<float>(i+1);
} }
{ {
@ -314,19 +314,17 @@ TEST(Objective, CPU_vs_CUDA) {
obj->GetGradient(preds, info, 0, &cuda_out_preds); obj->GetGradient(preds, info, 0, &cuda_out_preds);
} }
auto& h_cpu_out = cpu_out_preds.HostVector(); auto h_cpu_out = cpu_out_preds.HostView();
auto& h_cuda_out = cuda_out_preds.HostVector(); auto h_cuda_out = cuda_out_preds.HostView();
float sgrad = 0; float sgrad = 0;
float shess = 0; float shess = 0;
for (size_t i = 0; i < kRows; ++i) { for (size_t i = 0; i < kRows; ++i) {
sgrad += std::pow(h_cpu_out[i].GetGrad() - h_cuda_out[i].GetGrad(), 2); sgrad += std::pow(h_cpu_out(i).GetGrad() - h_cuda_out(i).GetGrad(), 2);
shess += std::pow(h_cpu_out[i].GetHess() - h_cuda_out[i].GetHess(), 2); shess += std::pow(h_cpu_out(i).GetHess() - h_cuda_out(i).GetHess(), 2);
} }
ASSERT_NEAR(sgrad, 0.0f, kRtEps); ASSERT_NEAR(sgrad, 0.0f, kRtEps);
ASSERT_NEAR(shess, 0.0f, kRtEps); ASSERT_NEAR(shess, 0.0f, kRtEps);
delete obj;
} }
#endif #endif


@ -189,11 +189,10 @@ void TestUpdatePredictionCache(bool use_subsampling) {
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses); auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
HostDeviceVector<GradientPair> gpair; linalg::Matrix<GradientPair> gpair({kRows, kClasses}, ctx.Device());
auto& h_gpair = gpair.HostVector(); auto h_gpair = gpair.HostView();
h_gpair.resize(kRows * kClasses);
for (size_t i = 0; i < kRows * kClasses; ++i) { for (size_t i = 0; i < kRows * kClasses; ++i) {
h_gpair[i] = {static_cast<float>(i), 1}; std::apply(h_gpair, linalg::UnravelIndex(i, kRows, kClasses)) = {static_cast<float>(i), 1};
} }
PredictionCacheEntry predtion_cache; PredictionCacheEntry predtion_cache;
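
A sketch of the fill pattern used above: std::apply unpacks the (row, class) tuple returned by linalg::UnravelIndex and invokes the two-dimensional host view with it. The function name is hypothetical; the UnravelIndex and HostView calls are the ones shown in this hunk.

#include <cstddef>
#include <tuple>  // std::apply

void FillByFlatIndexSketch(xgboost::linalg::Matrix<xgboost::GradientPair>* gpair,
                           std::size_t n_rows, std::size_t n_classes) {
  gpair->Reshape(n_rows, n_classes);
  auto h = gpair->HostView();
  for (std::size_t i = 0; i < n_rows * n_classes; ++i) {
    std::apply(h, xgboost::linalg::UnravelIndex(i, n_rows, n_classes)) =
        xgboost::GradientPair{static_cast<float>(i), 1.0f};
  }
}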


@ -68,10 +68,12 @@ class TestL1MultiTarget : public ::testing::Test {
} }
} }
void RunTest(std::string const& tree_method, bool weight) { void RunTest(Context const* ctx, std::string const& tree_method, bool weight) {
auto p_fmat = weight ? Xyw_ : Xy_; auto p_fmat = weight ? Xyw_ : Xy_;
std::unique_ptr<Learner> learner{Learner::Create({p_fmat})}; std::unique_ptr<Learner> learner{Learner::Create({p_fmat})};
learner->SetParams(Args{{"tree_method", tree_method}, {"objective", "reg:absoluteerror"}}); learner->SetParams(Args{{"tree_method", tree_method},
{"objective", "reg:absoluteerror"},
{"device", ctx->DeviceName()}});
learner->Configure(); learner->Configure();
for (auto i = 0; i < 4; ++i) { for (auto i = 0; i < 4; ++i) {
learner->UpdateOneIter(i, p_fmat); learner->UpdateOneIter(i, p_fmat);
@ -87,7 +89,9 @@ class TestL1MultiTarget : public ::testing::Test {
for (bst_target_t t{0}; t < p_fmat->Info().labels.Shape(1); ++t) { for (bst_target_t t{0}; t < p_fmat->Info().labels.Shape(1); ++t) {
auto t_Xy = weight ? single_w_[t] : single_[t]; auto t_Xy = weight ? single_w_[t] : single_[t];
std::unique_ptr<Learner> sl{Learner::Create({t_Xy})}; std::unique_ptr<Learner> sl{Learner::Create({t_Xy})};
sl->SetParams(Args{{"tree_method", tree_method}, {"objective", "reg:absoluteerror"}}); sl->SetParams(Args{{"tree_method", tree_method},
{"objective", "reg:absoluteerror"},
{"device", ctx->DeviceName()}});
sl->Configure(); sl->Configure();
sl->UpdateOneIter(0, t_Xy); sl->UpdateOneIter(0, t_Xy);
Json s_config{Object{}}; Json s_config{Object{}};
@ -104,20 +108,32 @@ class TestL1MultiTarget : public ::testing::Test {
ASSERT_FLOAT_EQ(mean, base_score); ASSERT_FLOAT_EQ(mean, base_score);
} }
void RunTest(std::string const& tree_method) { void RunTest(Context const* ctx, std::string const& tree_method) {
this->RunTest(tree_method, false); this->RunTest(ctx, tree_method, false);
this->RunTest(tree_method, true); this->RunTest(ctx, tree_method, true);
} }
}; };
TEST_F(TestL1MultiTarget, Hist) { this->RunTest("hist"); } TEST_F(TestL1MultiTarget, Hist) {
Context ctx;
this->RunTest(&ctx, "hist");
}
TEST_F(TestL1MultiTarget, Exact) { this->RunTest("exact"); } TEST_F(TestL1MultiTarget, Exact) {
Context ctx;
this->RunTest(&ctx, "exact");
}
TEST_F(TestL1MultiTarget, Approx) { this->RunTest("approx"); } TEST_F(TestL1MultiTarget, Approx) {
Context ctx;
this->RunTest(&ctx, "approx");
}
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
TEST_F(TestL1MultiTarget, GpuHist) { this->RunTest("gpu_hist"); } TEST_F(TestL1MultiTarget, GpuHist) {
auto ctx = MakeCUDACtx(0);
this->RunTest(&ctx, "hist");
}
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA)
TEST(MultiStrategy, Configure) { TEST(MultiStrategy, Configure) {

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2022 by XGBoost Contributors * Copyright 2022-2023, XGBoost Contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/linalg.h> #include <xgboost/linalg.h>
@ -8,17 +8,17 @@
#include "../../src/tree/fit_stump.h" #include "../../src/tree/fit_stump.h"
#include "../helpers.h" #include "../helpers.h"
namespace xgboost { namespace xgboost::tree {
namespace tree {
namespace { namespace {
void TestFitStump(Context const *ctx, DataSplitMode split = DataSplitMode::kRow) { void TestFitStump(Context const *ctx, DataSplitMode split = DataSplitMode::kRow) {
std::size_t constexpr kRows = 16, kTargets = 2; std::size_t constexpr kRows = 16, kTargets = 2;
HostDeviceVector<GradientPair> gpair; linalg::Matrix<GradientPair> gpair;
auto &h_gpair = gpair.HostVector(); gpair.SetDevice(ctx->Device());
h_gpair.resize(kRows * kTargets); gpair.Reshape(kRows, kTargets);
auto h_gpair = gpair.HostView();
for (std::size_t i = 0; i < kRows; ++i) { for (std::size_t i = 0; i < kRows; ++i) {
for (std::size_t t = 0; t < kTargets; ++t) { for (std::size_t t = 0; t < kTargets; ++t) {
h_gpair.at(i * kTargets + t) = GradientPair{static_cast<float>(i), 1}; h_gpair(i, t) = GradientPair{static_cast<float>(i), 1};
} }
} }
linalg::Vector<float> out; linalg::Vector<float> out;
@ -53,6 +53,4 @@ TEST(InitEstimation, FitStumpColumnSplit) {
auto constexpr kWorldSize{3}; auto constexpr kWorldSize{3};
RunWithInMemoryCommunicator(kWorldSize, &TestFitStump, &ctx, DataSplitMode::kCol); RunWithInMemoryCommunicator(kWorldSize, &TestFitStump, &ctx, DataSplitMode::kCol);
} }
} // namespace xgboost::tree
} // namespace tree
} // namespace xgboost

View File

@ -214,7 +214,7 @@ TEST(GpuHist, TestHistogramIndex) {
TestHistogramIndexImpl(); TestHistogramIndexImpl();
} }
void UpdateTree(Context const* ctx, HostDeviceVector<GradientPair>* gpair, DMatrix* dmat, void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix* dmat,
size_t gpu_page_size, RegTree* tree, HostDeviceVector<bst_float>* preds, size_t gpu_page_size, RegTree* tree, HostDeviceVector<bst_float>* preds,
float subsample = 1.0f, const std::string& sampling_method = "uniform", float subsample = 1.0f, const std::string& sampling_method = "uniform",
int max_bin = 2) { int max_bin = 2) {
@ -264,7 +264,8 @@ TEST(GpuHist, UniformSampling) {
// Create an in-memory DMatrix. // Create an in-memory DMatrix.
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true)); std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
auto gpair = GenerateRandomGradients(kRows); linalg::Matrix<GradientPair> gpair({kRows}, Context{}.MakeCUDA().Ordinal());
gpair.Data()->Copy(GenerateRandomGradients(kRows));
// Build a tree using the in-memory DMatrix. // Build a tree using the in-memory DMatrix.
RegTree tree; RegTree tree;
@ -294,7 +295,8 @@ TEST(GpuHist, GradientBasedSampling) {
// Create an in-memory DMatrix. // Create an in-memory DMatrix.
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true)); std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
auto gpair = GenerateRandomGradients(kRows); linalg::Matrix<GradientPair> gpair({kRows}, MakeCUDACtx(0).Ordinal());
gpair.Data()->Copy(GenerateRandomGradients(kRows));
// Build a tree using the in-memory DMatrix. // Build a tree using the in-memory DMatrix.
RegTree tree; RegTree tree;
@ -330,11 +332,12 @@ TEST(GpuHist, ExternalMemory) {
// Create a single batch DMatrix. // Create a single batch DMatrix.
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrix(kRows, kCols, 1, tmpdir.path + "/cache")); std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrix(kRows, kCols, 1, tmpdir.path + "/cache"));
auto gpair = GenerateRandomGradients(kRows); Context ctx(MakeCUDACtx(0));
linalg::Matrix<GradientPair> gpair({kRows}, ctx.Ordinal());
gpair.Data()->Copy(GenerateRandomGradients(kRows));
// Build a tree using the in-memory DMatrix. // Build a tree using the in-memory DMatrix.
RegTree tree; RegTree tree;
Context ctx(MakeCUDACtx(0));
HostDeviceVector<bst_float> preds(kRows, 0.0, 0); HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows); UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
// Build another tree using multiple ELLPACK pages. // Build another tree using multiple ELLPACK pages.
@ -367,12 +370,13 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
std::unique_ptr<DMatrix> dmat_ext( std::unique_ptr<DMatrix> dmat_ext(
CreateSparsePageDMatrix(kRows, kCols, kRows / kPageSize, tmpdir.path + "/cache")); CreateSparsePageDMatrix(kRows, kCols, kRows / kPageSize, tmpdir.path + "/cache"));
auto gpair = GenerateRandomGradients(kRows); Context ctx(MakeCUDACtx(0));
linalg::Matrix<GradientPair> gpair({kRows}, ctx.Ordinal());
gpair.Data()->Copy(GenerateRandomGradients(kRows));
// Build a tree using the in-memory DMatrix. // Build a tree using the in-memory DMatrix.
auto rng = common::GlobalRandom(); auto rng = common::GlobalRandom();
Context ctx(MakeCUDACtx(0));
RegTree tree; RegTree tree;
HostDeviceVector<bst_float> preds(kRows, 0.0, 0); HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, kSubsample, kSamplingMethod, kRows); UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, kSubsample, kSamplingMethod, kRows);
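A similar sketch for the device-side variant used by the GPU histogram tests above, assuming a CUDA-enabled build: the matrix is created with the CUDA ordinal and the flat random gradients are copied into its backing vector. GenerateRandomGradients and MakeCUDACtx are the test helpers already referenced in this diff; the wrapper function and sample count are illustrative.

// Sketch: wrapping randomly generated gradients in a device-resident matrix.
#include <cstddef>
#include <xgboost/base.h>
#include <xgboost/context.h>
#include <xgboost/linalg.h>
#include "../helpers.h"  // GenerateRandomGradients, MakeCUDACtx (test utilities)

void SketchDeviceGradients() {
  using namespace xgboost;
  std::size_t constexpr kRows = 256;  // illustrative sample count
  Context ctx(MakeCUDACtx(0));
  // One gradient pair per sample; the single-target case keeps shape {kRows}.
  linalg::Matrix<GradientPair> gpair({kRows}, ctx.Ordinal());
  gpair.Data()->Copy(GenerateRandomGradients(kRows));
}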

View File

@ -26,9 +26,11 @@ TEST(GrowHistMaker, InteractionConstraint) {
auto constexpr kRows = 32; auto constexpr kRows = 32;
auto constexpr kCols = 16; auto constexpr kCols = 16;
auto p_dmat = GenerateDMatrix(kRows, kCols); auto p_dmat = GenerateDMatrix(kRows, kCols);
auto p_gradients = GenerateGradients(kRows);
Context ctx; Context ctx;
linalg::Matrix<GradientPair> gpair({kRows}, ctx.Ordinal());
gpair.Data()->Copy(GenerateRandomGradients(kRows));
ObjInfo task{ObjInfo::kRegression}; ObjInfo task{ObjInfo::kRegression};
{ {
// With constraints // With constraints
@ -40,7 +42,7 @@ TEST(GrowHistMaker, InteractionConstraint) {
Args{{"interaction_constraints", "[[0, 1]]"}, {"num_feature", std::to_string(kCols)}}); Args{{"interaction_constraints", "[[0, 1]]"}, {"num_feature", std::to_string(kCols)}});
std::vector<HostDeviceVector<bst_node_t>> position(1); std::vector<HostDeviceVector<bst_node_t>> position(1);
updater->Configure(Args{}); updater->Configure(Args{});
updater->Update(&param, p_gradients.get(), p_dmat.get(), position, {&tree}); updater->Update(&param, &gpair, p_dmat.get(), position, {&tree});
ASSERT_EQ(tree.NumExtraNodes(), 4); ASSERT_EQ(tree.NumExtraNodes(), 4);
ASSERT_EQ(tree[0].SplitIndex(), 1); ASSERT_EQ(tree[0].SplitIndex(), 1);
@ -57,7 +59,7 @@ TEST(GrowHistMaker, InteractionConstraint) {
TrainParam param; TrainParam param;
param.Init(Args{}); param.Init(Args{});
updater->Configure(Args{}); updater->Configure(Args{});
updater->Update(&param, p_gradients.get(), p_dmat.get(), position, {&tree}); updater->Update(&param, &gpair, p_dmat.get(), position, {&tree});
ASSERT_EQ(tree.NumExtraNodes(), 10); ASSERT_EQ(tree.NumExtraNodes(), 10);
ASSERT_EQ(tree[0].SplitIndex(), 1); ASSERT_EQ(tree[0].SplitIndex(), 1);
@ -70,9 +72,12 @@ TEST(GrowHistMaker, InteractionConstraint) {
namespace { namespace {
void VerifyColumnSplit(int32_t rows, bst_feature_t cols, bool categorical, void VerifyColumnSplit(int32_t rows, bst_feature_t cols, bool categorical,
RegTree const& expected_tree) { RegTree const& expected_tree) {
auto p_dmat = GenerateDMatrix(rows, cols, categorical);
auto p_gradients = GenerateGradients(rows);
Context ctx; Context ctx;
auto p_dmat = GenerateDMatrix(rows, cols, categorical);
linalg::Matrix<GradientPair> gpair({rows}, ctx.Ordinal());
gpair.Data()->Copy(GenerateRandomGradients(rows));
ObjInfo task{ObjInfo::kRegression}; ObjInfo task{ObjInfo::kRegression};
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)}; std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
std::vector<HostDeviceVector<bst_node_t>> position(1); std::vector<HostDeviceVector<bst_node_t>> position(1);
@ -84,7 +89,7 @@ void VerifyColumnSplit(int32_t rows, bst_feature_t cols, bool categorical,
TrainParam param; TrainParam param;
param.Init(Args{}); param.Init(Args{});
updater->Configure(Args{}); updater->Configure(Args{});
updater->Update(&param, p_gradients.get(), sliced.get(), position, {&tree}); updater->Update(&param, &gpair, sliced.get(), position, {&tree});
Json json{Object{}}; Json json{Object{}};
tree.SaveModel(&json); tree.SaveModel(&json);
@ -100,15 +105,16 @@ void TestColumnSplit(bool categorical) {
RegTree expected_tree{1u, kCols}; RegTree expected_tree{1u, kCols};
ObjInfo task{ObjInfo::kRegression}; ObjInfo task{ObjInfo::kRegression};
{ {
auto p_dmat = GenerateDMatrix(kRows, kCols, categorical);
auto p_gradients = GenerateGradients(kRows);
Context ctx; Context ctx;
auto p_dmat = GenerateDMatrix(kRows, kCols, categorical);
linalg::Matrix<GradientPair> gpair({kRows}, ctx.Ordinal());
gpair.Data()->Copy(GenerateRandomGradients(kRows));
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)}; std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
std::vector<HostDeviceVector<bst_node_t>> position(1); std::vector<HostDeviceVector<bst_node_t>> position(1);
TrainParam param; TrainParam param;
param.Init(Args{}); param.Init(Args{});
updater->Configure(Args{}); updater->Configure(Args{});
updater->Update(&param, p_gradients.get(), p_dmat.get(), position, {&expected_tree}); updater->Update(&param, &gpair, p_dmat.get(), position, {&expected_tree});
} }
auto constexpr kWorldSize = 2; auto constexpr kWorldSize = 2;

View File

@ -69,7 +69,7 @@ class TestPredictionCache : public ::testing::Test {
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create(updater_name, ctx, &task)}; std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create(updater_name, ctx, &task)};
RegTree tree; RegTree tree;
std::vector<RegTree*> trees{&tree}; std::vector<RegTree*> trees{&tree};
auto gpair = GenerateRandomGradients(n_samples_); auto gpair = GenerateRandomGradients(ctx, n_samples_, 1);
tree::TrainParam param; tree::TrainParam param;
param.UpdateAllowUnknown(Args{{"max_bin", "64"}}); param.UpdateAllowUnknown(Args{{"max_bin", "64"}});

View File

@ -21,15 +21,13 @@ TEST(Updater, Prune) {
std::vector<std::pair<std::string, std::string>> cfg; std::vector<std::pair<std::string, std::string>> cfg;
cfg.emplace_back("num_feature", std::to_string(kCols)); cfg.emplace_back("num_feature", std::to_string(kCols));
cfg.emplace_back("min_split_loss", "10"); cfg.emplace_back("min_split_loss", "10");
Context ctx;
// These data are just placeholders. // These data are just placeholders.
HostDeviceVector<GradientPair> gpair = linalg::Matrix<GradientPair> gpair
{ {0.50f, 0.25f}, {0.50f, 0.25f}, {0.50f, 0.25f}, {0.50f, 0.25f}, {{ {0.50f, 0.25f}, {0.50f, 0.25f}, {0.50f, 0.25f}, {0.50f, 0.25f},
{0.25f, 0.24f}, {0.25f, 0.24f}, {0.25f, 0.24f}, {0.25f, 0.24f} }; {0.25f, 0.24f}, {0.25f, 0.24f}, {0.25f, 0.24f}, {0.25f, 0.24f} }, {8, 1}, ctx.Device()};
std::shared_ptr<DMatrix> p_dmat { std::shared_ptr<DMatrix> p_dmat{RandomDataGenerator{32, 10, 0}.GenerateDMatrix()};
RandomDataGenerator{32, 10, 0}.GenerateDMatrix() };
Context ctx;
// prepare tree // prepare tree
RegTree tree = RegTree{1u, kCols}; RegTree tree = RegTree{1u, kCols};
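As a side note, a sketch of the in-place construction the pruner test above switches to: the gradient matrix can be built directly from an initializer list plus an explicit shape and device, replacing the old brace-initialized HostDeviceVector. The values and the wrapper function are placeholders.

// Sketch: constructing a small gradient matrix from literal values.
#include <xgboost/base.h>
#include <xgboost/context.h>
#include <xgboost/linalg.h>

void SketchLiteralGradients() {
  using namespace xgboost;
  Context ctx;
  // Four samples, one target each; shape {4, 1} placed on the context's device.
  linalg::Matrix<GradientPair> gpair{
      {{0.50f, 0.25f}, {0.50f, 0.25f}, {0.25f, 0.24f}, {0.25f, 0.24f}},
      {4, 1}, ctx.Device()};
}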

View File

@ -202,13 +202,13 @@ TEST(QuantileHist, PartitionerColSplit) { TestColumnSplitPartitioner<CPUExpandEn
TEST(QuantileHist, MultiPartitionerColSplit) { TestColumnSplitPartitioner<MultiExpandEntry>(3); } TEST(QuantileHist, MultiPartitionerColSplit) { TestColumnSplitPartitioner<MultiExpandEntry>(3); }
namespace { namespace {
void VerifyColumnSplit(bst_row_t rows, bst_feature_t cols, bst_target_t n_targets, void VerifyColumnSplit(Context const* ctx, bst_row_t rows, bst_feature_t cols, bst_target_t n_targets,
RegTree const& expected_tree) { RegTree const& expected_tree) {
auto Xy = RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(true); auto Xy = RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(true);
auto p_gradients = GenerateGradients(rows, n_targets); linalg::Matrix<GradientPair> gpair = GenerateRandomGradients(ctx, rows, n_targets);
Context ctx;
ObjInfo task{ObjInfo::kRegression}; ObjInfo task{ObjInfo::kRegression};
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_quantile_histmaker", &ctx, &task)}; std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_quantile_histmaker", ctx, &task)};
std::vector<HostDeviceVector<bst_node_t>> position(1); std::vector<HostDeviceVector<bst_node_t>> position(1);
std::unique_ptr<DMatrix> sliced{Xy->SliceCol(collective::GetWorldSize(), collective::GetRank())}; std::unique_ptr<DMatrix> sliced{Xy->SliceCol(collective::GetWorldSize(), collective::GetRank())};
@ -217,7 +217,7 @@ void VerifyColumnSplit(bst_row_t rows, bst_feature_t cols, bst_target_t n_target
TrainParam param; TrainParam param;
param.Init(Args{}); param.Init(Args{});
updater->Configure(Args{}); updater->Configure(Args{});
updater->Update(&param, p_gradients.get(), sliced.get(), position, {&tree}); updater->Update(&param, &gpair, sliced.get(), position, {&tree});
Json json{Object{}}; Json json{Object{}};
tree.SaveModel(&json); tree.SaveModel(&json);
@ -232,21 +232,21 @@ void TestColumnSplit(bst_target_t n_targets) {
RegTree expected_tree{n_targets, kCols}; RegTree expected_tree{n_targets, kCols};
ObjInfo task{ObjInfo::kRegression}; ObjInfo task{ObjInfo::kRegression};
Context ctx;
{ {
auto Xy = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true); auto Xy = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);
auto p_gradients = GenerateGradients(kRows, n_targets); auto gpair = GenerateRandomGradients(&ctx, kRows, n_targets);
Context ctx;
std::unique_ptr<TreeUpdater> updater{ std::unique_ptr<TreeUpdater> updater{
TreeUpdater::Create("grow_quantile_histmaker", &ctx, &task)}; TreeUpdater::Create("grow_quantile_histmaker", &ctx, &task)};
std::vector<HostDeviceVector<bst_node_t>> position(1); std::vector<HostDeviceVector<bst_node_t>> position(1);
TrainParam param; TrainParam param;
param.Init(Args{}); param.Init(Args{});
updater->Configure(Args{}); updater->Configure(Args{});
updater->Update(&param, p_gradients.get(), Xy.get(), position, {&expected_tree}); updater->Update(&param, &gpair, Xy.get(), position, {&expected_tree});
} }
auto constexpr kWorldSize = 2; auto constexpr kWorldSize = 2;
RunWithInMemoryCommunicator(kWorldSize, VerifyColumnSplit, kRows, kCols, n_targets, RunWithInMemoryCommunicator(kWorldSize, VerifyColumnSplit, &ctx, kRows, kCols, n_targets,
std::cref(expected_tree)); std::cref(expected_tree));
} }
} // anonymous namespace } // anonymous namespace
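For completeness, a sketch of how the matrix-backed gradients now reach a tree updater, assembled from the calls exercised by the column-split tests above. The wrapper function, sizes, and exact include list (notably the TrainParam header, which mirrors what the surrounding test files pull in) are assumptions rather than code from this change.

// Sketch: driving a tree updater with linalg::Matrix gradients.
#include <memory>
#include <vector>
#include <xgboost/base.h>
#include <xgboost/context.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/task.h>
#include <xgboost/tree_model.h>
#include <xgboost/tree_updater.h>
#include "../helpers.h"  // RandomDataGenerator, GenerateRandomGradients
// ...plus the TrainParam/test headers included by the surrounding test files.

void SketchUpdaterCall() {
  using namespace xgboost;
  bst_row_t constexpr kRows = 64;
  bst_feature_t constexpr kCols = 8;
  bst_target_t constexpr kTargets = 3;
  Context ctx;
  auto p_fmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);
  // The helper now hands back a (kRows, kTargets) linalg::Matrix of gradients.
  auto gpair = GenerateRandomGradients(&ctx, kRows, kTargets);
  ObjInfo task{ObjInfo::kRegression};
  std::unique_ptr<TreeUpdater> updater{
      TreeUpdater::Create("grow_quantile_histmaker", &ctx, &task)};
  tree::TrainParam param;
  param.Init(Args{});
  updater->Configure(Args{});
  RegTree tree{kTargets, kCols};
  std::vector<HostDeviceVector<bst_node_t>> position(1);
  // Updaters take a pointer to the gradient matrix instead of a HostDeviceVector.
  updater->Update(&param, &gpair, p_fmat.get(), position, {&tree});
}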

View File

@ -17,10 +17,11 @@ namespace xgboost::tree {
TEST(Updater, Refresh) { TEST(Updater, Refresh) {
bst_row_t constexpr kRows = 8; bst_row_t constexpr kRows = 8;
bst_feature_t constexpr kCols = 16; bst_feature_t constexpr kCols = 16;
Context ctx;
HostDeviceVector<GradientPair> gpair = linalg::Matrix<GradientPair> gpair
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} }; {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} }, {8, 1}, ctx.Device()};
std::shared_ptr<DMatrix> p_dmat{ std::shared_ptr<DMatrix> p_dmat{
RandomDataGenerator{kRows, kCols, 0.4f}.Seed(3).GenerateDMatrix()}; RandomDataGenerator{kRows, kCols, 0.4f}.Seed(3).GenerateDMatrix()};
std::vector<std::pair<std::string, std::string>> cfg{ std::vector<std::pair<std::string, std::string>> cfg{
@ -29,7 +30,6 @@ TEST(Updater, Refresh) {
{"reg_lambda", "1"}}; {"reg_lambda", "1"}};
RegTree tree = RegTree{1u, kCols}; RegTree tree = RegTree{1u, kCols};
Context ctx;
std::vector<RegTree*> trees{&tree}; std::vector<RegTree*> trees{&tree};
ObjInfo task{ObjInfo::kRegression}; ObjInfo task{ObjInfo::kRegression};

View File

@ -16,7 +16,7 @@ namespace xgboost {
class UpdaterTreeStatTest : public ::testing::Test { class UpdaterTreeStatTest : public ::testing::Test {
protected: protected:
std::shared_ptr<DMatrix> p_dmat_; std::shared_ptr<DMatrix> p_dmat_;
HostDeviceVector<GradientPair> gpairs_; linalg::Matrix<GradientPair> gpairs_;
size_t constexpr static kRows = 10; size_t constexpr static kRows = 10;
size_t constexpr static kCols = 10; size_t constexpr static kCols = 10;
@ -24,8 +24,8 @@ class UpdaterTreeStatTest : public ::testing::Test {
void SetUp() override { void SetUp() override {
p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatrix(true); p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatrix(true);
auto g = GenerateRandomGradients(kRows); auto g = GenerateRandomGradients(kRows);
gpairs_.Resize(kRows); gpairs_.Reshape(kRows, 1);
gpairs_.Copy(g); gpairs_.Data()->Copy(g);
} }
void RunTest(std::string updater) { void RunTest(std::string updater) {
@ -63,7 +63,7 @@ TEST_F(UpdaterTreeStatTest, Approx) { this->RunTest("grow_histmaker"); }
class UpdaterEtaTest : public ::testing::Test { class UpdaterEtaTest : public ::testing::Test {
protected: protected:
std::shared_ptr<DMatrix> p_dmat_; std::shared_ptr<DMatrix> p_dmat_;
HostDeviceVector<GradientPair> gpairs_; linalg::Matrix<GradientPair> gpairs_;
size_t constexpr static kRows = 10; size_t constexpr static kRows = 10;
size_t constexpr static kCols = 10; size_t constexpr static kCols = 10;
size_t constexpr static kClasses = 10; size_t constexpr static kClasses = 10;
@ -71,8 +71,8 @@ class UpdaterEtaTest : public ::testing::Test {
void SetUp() override { void SetUp() override {
p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatrix(true, false, kClasses); p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatrix(true, false, kClasses);
auto g = GenerateRandomGradients(kRows); auto g = GenerateRandomGradients(kRows);
gpairs_.Resize(kRows); gpairs_.Reshape(kRows, 1);
gpairs_.Copy(g); gpairs_.Data()->Copy(g);
} }
void RunTest(std::string updater) { void RunTest(std::string updater) {
@ -125,14 +125,15 @@ TEST_F(UpdaterEtaTest, GpuHist) { this->RunTest("grow_gpu_hist"); }
class TestMinSplitLoss : public ::testing::Test { class TestMinSplitLoss : public ::testing::Test {
std::shared_ptr<DMatrix> dmat_; std::shared_ptr<DMatrix> dmat_;
HostDeviceVector<GradientPair> gpair_; linalg::Matrix<GradientPair> gpair_;
void SetUp() override { void SetUp() override {
constexpr size_t kRows = 32; constexpr size_t kRows = 32;
constexpr size_t kCols = 16; constexpr size_t kCols = 16;
constexpr float kSparsity = 0.6; constexpr float kSparsity = 0.6;
dmat_ = RandomDataGenerator(kRows, kCols, kSparsity).Seed(3).GenerateDMatrix(); dmat_ = RandomDataGenerator(kRows, kCols, kSparsity).Seed(3).GenerateDMatrix();
gpair_ = GenerateRandomGradients(kRows); gpair_.Reshape(kRows, 1);
gpair_.Data()->Copy(GenerateRandomGradients(kRows));
} }
std::int32_t Update(Context const* ctx, std::string updater, float gamma) { std::int32_t Update(Context const* ctx, std::string updater, float gamma) {

View File

@ -1,3 +1,4 @@
import itertools
import json import json
import os import os
import sys import sys
@ -158,6 +159,96 @@ def test_classififer():
clf.fit(X, y) clf.fit(X, y)
@pytest.mark.parametrize(
"use_cupy,tree_method,device,order,gdtype,strategy",
[
c
for c in itertools.product(
(True, False),
("hist", "approx"),
("cpu", "cuda"),
("C", "F"),
("float64", "float32"),
("one_output_per_tree", "multi_output_tree"),
)
],
)
def test_custom_objective(
use_cupy: bool,
tree_method: str,
device: str,
order: str,
gdtype: str,
strategy: str,
) -> None:
from sklearn.datasets import load_iris
X, y = load_iris(return_X_y=True)
params = {
"tree_method": tree_method,
"device": device,
"n_estimators": 8,
"multi_strategy": strategy,
}
obj = tm.softprob_obj(y.max() + 1, use_cupy=use_cupy, order=order, gdtype=gdtype)
clf = xgb.XGBClassifier(objective=obj, **params)
if strategy == "multi_output_tree" and tree_method == "approx":
with pytest.raises(ValueError, match=r"Only the hist"):
clf.fit(X, y)
return
if strategy == "multi_output_tree" and device == "cuda":
with pytest.raises(ValueError, match=r"GPU is not yet"):
clf.fit(X, y)
return
clf.fit(X, y)
clf_1 = xgb.XGBClassifier(**params)
clf_1.fit(X, y)
np.testing.assert_allclose(clf.predict_proba(X), clf_1.predict_proba(X), rtol=1e-4)
params["n_estimators"] = 2
def wrong_shape(labels, predt):
grad, hess = obj(labels, predt)
return grad[:, :-1], hess[:, :-1]
with pytest.raises(ValueError, match="should be equal to the number of"):
clf = xgb.XGBClassifier(objective=wrong_shape, **params)
clf.fit(X, y)
def wrong_shape_1(labels, predt):
grad, hess = obj(labels, predt)
return grad[:-1, :], hess[:-1, :]
with pytest.raises(ValueError, match="Mismatched size between the gradient"):
clf = xgb.XGBClassifier(objective=wrong_shape_1, **params)
clf.fit(X, y)
def wrong_shape_2(labels, predt):
grad, hess = obj(labels, predt)
return grad[:, :], hess[:-1, :]
with pytest.raises(ValueError, match="Mismatched shape between the gradient"):
clf = xgb.XGBClassifier(objective=wrong_shape_2, **params)
clf.fit(X, y)
def wrong_shape_3(labels, predt):
grad, hess = obj(labels, predt)
grad = grad.reshape(grad.size)
hess = hess.reshape(hess.size)
return grad, hess
with pytest.warns(FutureWarning, match="required to be"):
clf = xgb.XGBClassifier(objective=wrong_shape_3, **params)
clf.fit(X, y)
@pytest.mark.skipif(**tm.no_pandas()) @pytest.mark.skipif(**tm.no_pandas())
def test_ranking_qid_df(): def test_ranking_qid_df():
import cudf import cudf