Remove omp_get_max_threads (#7608)
This is the last PR for removing the OpenMP global variable.

* Add a context object to `DMatrix`. This bridges `DMatrix` with https://github.com/dmlc/xgboost/issues/7308 .
* Require the context to be available when the booster is constructed.
* Add `n_threads` support to the R CSC DMatrix constructor.
* Remove `omp_get_max_threads` from the R glue code.
* Remove the threading utilities that rely on the OpenMP global variable.
This commit is contained in: parent 028bdc1740 · commit 81210420c6
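The recurring pattern in the hunks below: raw `#pragma omp parallel for` regions, which read the process-global OpenMP thread setting, are replaced by calls that receive an explicit thread count from a context object. A minimal sketch of the two styles, assuming a simplified `Context` stand-in for xgboost::GenericParameter (not the real class); compile with -fopenmp:

// Sketch only: contrasts the old omp-global threading style with the
// explicit-thread-count style this commit adopts.
#include <algorithm>
#include <cstdint>
#include <vector>
#include <omp.h>

struct Context {
  std::int32_t nthread{0};  // user request; <= 0 means "use what is available"
  std::int32_t Threads() const {
    return nthread > 0 ? nthread
                       : std::min(omp_get_num_procs(), omp_get_max_threads());
  }
};

int main() {
  std::vector<double> out(1024);

  // Old style: the thread count comes from mutable global OpenMP state.
  #pragma omp parallel for schedule(static)
  for (std::int64_t i = 0; i < 1024; ++i) { out[i] = 2.0 * i; }

  // New style: the thread count travels with a context object, so the global
  // state is never written and other OpenMP users in the process are unaffected.
  Context ctx;
  #pragma omp parallel for schedule(static) num_threads(ctx.Threads())
  for (std::int64_t i = 0; i < 1024; ++i) { out[i] = 2.0 * i; }
  return 0;
}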
@@ -33,7 +33,9 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
     handle <- .Call(XGDMatrixCreateFromMat_R, data, missing, as.integer(NVL(nthread, -1)))
     cnames <- colnames(data)
   } else if (inherits(data, "dgCMatrix")) {
-    handle <- .Call(XGDMatrixCreateFromCSC_R, data@p, data@i, data@x, nrow(data))
+    handle <- .Call(
+      XGDMatrixCreateFromCSC_R, data@p, data@i, data@x, nrow(data), as.integer(NVL(nthread, -1))
+    )
     cnames <- colnames(data)
   } else {
     stop("xgb.DMatrix does not support construction from ", typeof(data))
@@ -37,7 +37,7 @@ extern SEXP XGBoosterSetAttr_R(SEXP, SEXP, SEXP);
 extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP);
 extern SEXP XGBoosterUpdateOneIter_R(SEXP, SEXP, SEXP);
 extern SEXP XGCheckNullPtr_R(SEXP);
-extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP);
+extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP, SEXP);
 extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
 extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP, SEXP);
 extern SEXP XGDMatrixGetInfo_R(SEXP, SEXP);
@@ -72,7 +72,7 @@ static const R_CallMethodDef CallEntries[] = {
   {"XGBoosterSetParam_R", (DL_FUNC) &XGBoosterSetParam_R, 3},
   {"XGBoosterUpdateOneIter_R", (DL_FUNC) &XGBoosterUpdateOneIter_R, 3},
   {"XGCheckNullPtr_R", (DL_FUNC) &XGCheckNullPtr_R, 1},
-  {"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 4},
+  {"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 5},
   {"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},
   {"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 3},
   {"XGDMatrixGetInfo_R", (DL_FUNC) &XGDMatrixGetInfo_R, 2},
@@ -1,16 +1,23 @@
-// Copyright (c) 2014 by Contributors
-#include <dmlc/logging.h>
+/**
+ * Copyright 2014-2022 by XGBoost Contributors
+ */
+#include <dmlc/common.h>
 #include <dmlc/omp.h>
 #include <xgboost/c_api.h>
-#include <vector>
+#include <xgboost/data.h>
+#include <xgboost/generic_parameters.h>
+#include <xgboost/logging.h>
+
+#include <cstdio>
+#include <cstring>
+#include <sstream>
 #include <string>
 #include <utility>
-#include <cstring>
-#include <cstdio>
-#include <sstream>
+#include <vector>
 
 #include "../../src/c_api/c_api_error.h"
+#include "../../src/common/threading_utils.h"
 
 #include "./xgboost_R.h"
 
 /*!
@@ -37,8 +44,21 @@
     error(XGBGetLastError()); \
   }
 
-using namespace dmlc;
+using dmlc::BeginPtr;
+
+xgboost::GenericParameter const *BoosterCtx(BoosterHandle handle) {
+  CHECK_HANDLE();
+  auto *learner = static_cast<xgboost::Learner *>(handle);
+  CHECK(learner);
+  return learner->Ctx();
+}
+
+xgboost::GenericParameter const *DMatrixCtx(DMatrixHandle handle) {
+  CHECK_HANDLE();
+  auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
+  CHECK(p_m);
+  return p_m->get()->Ctx();
+}
 
 XGB_DLL SEXP XGCheckNullPtr_R(SEXP handle) {
   return ScalarLogical(R_ExternalPtrAddr(handle) == NULL);
@@ -94,18 +114,13 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing, SEXP n_threads) {
     din = REAL(mat);
   }
   std::vector<float> data(nrow * ncol);
-  dmlc::OMPException exc;
   int32_t threads = xgboost::common::OmpGetNumThreads(asInteger(n_threads));
 
-#pragma omp parallel for schedule(static) num_threads(threads)
-  for (omp_ulong i = 0; i < nrow; ++i) {
-    exc.Run([&]() {
-      for (size_t j = 0; j < ncol; ++j) {
-        data[i * ncol +j] = is_int ? static_cast<float>(iin[i + nrow * j]) : din[i + nrow * j];
-      }
-    });
-  }
-  exc.Rethrow();
+  xgboost::common::ParallelFor(nrow, threads, [&](xgboost::omp_ulong i) {
+    for (size_t j = 0; j < ncol; ++j) {
+      data[i * ncol + j] = is_int ? static_cast<float>(iin[i + nrow * j]) : din[i + nrow * j];
+    }
+  });
   DMatrixHandle handle;
   CHECK_CALL(XGDMatrixCreateFromMat_omp(BeginPtr(data), nrow, ncol,
                                         asReal(missing), &handle, threads));
@@ -117,7 +132,7 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing, SEXP n_threads) {
 }
 
 XGB_DLL SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, SEXP indices, SEXP data,
-                                      SEXP num_row) {
+                                      SEXP num_row, SEXP n_threads) {
   SEXP ret;
   R_API_BEGIN();
   const int *p_indptr = INTEGER(indptr);
@@ -133,15 +148,11 @@ XGB_DLL SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, SEXP indices, SEXP data,
   for (size_t i = 0; i < nindptr; ++i) {
     col_ptr_[i] = static_cast<size_t>(p_indptr[i]);
   }
-  dmlc::OMPException exc;
-#pragma omp parallel for schedule(static)
-  for (int64_t i = 0; i < static_cast<int64_t>(ndata); ++i) {
-    exc.Run([&]() {
-      indices_[i] = static_cast<unsigned>(p_indices[i]);
-      data_[i] = static_cast<float>(p_data[i]);
-    });
-  }
-  exc.Rethrow();
+  int32_t threads = xgboost::common::OmpGetNumThreads(asInteger(n_threads));
+  xgboost::common::ParallelFor(ndata, threads, [&](xgboost::omp_ulong i) {
+    indices_[i] = static_cast<unsigned>(p_indices[i]);
+    data_[i] = static_cast<float>(p_data[i]);
+  });
   DMatrixHandle handle;
   CHECK_CALL(XGDMatrixCreateFromCSCEx(BeginPtr(col_ptr_), BeginPtr(indices_),
                                       BeginPtr(data_), nindptr, ndata,
@@ -186,31 +197,20 @@ XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
   R_API_BEGIN();
   int len = length(array);
   const char *name = CHAR(asChar(field));
-  dmlc::OMPException exc;
+  auto ctx = DMatrixCtx(R_ExternalPtrAddr(handle));
   if (!strcmp("group", name)) {
     std::vector<unsigned> vec(len);
-#pragma omp parallel for schedule(static)
-    for (int i = 0; i < len; ++i) {
-      exc.Run([&]() {
-        vec[i] = static_cast<unsigned>(INTEGER(array)[i]);
-      });
-    }
-    exc.Rethrow();
-    CHECK_CALL(XGDMatrixSetUIntInfo(R_ExternalPtrAddr(handle),
-                                    CHAR(asChar(field)),
-                                    BeginPtr(vec), len));
+    xgboost::common::ParallelFor(len, ctx->Threads(), [&](xgboost::omp_ulong i) {
+      vec[i] = static_cast<unsigned>(INTEGER(array)[i]);
+    });
+    CHECK_CALL(
+        XGDMatrixSetUIntInfo(R_ExternalPtrAddr(handle), CHAR(asChar(field)), BeginPtr(vec), len));
   } else {
     std::vector<float> vec(len);
-#pragma omp parallel for schedule(static)
-    for (int i = 0; i < len; ++i) {
-      exc.Run([&]() {
-        vec[i] = REAL(array)[i];
-      });
-    }
-    exc.Rethrow();
-    CHECK_CALL(XGDMatrixSetFloatInfo(R_ExternalPtrAddr(handle),
-                                     CHAR(asChar(field)),
-                                     BeginPtr(vec), len));
+    xgboost::common::ParallelFor(len, ctx->Threads(),
+                                 [&](xgboost::omp_ulong i) { vec[i] = REAL(array)[i]; });
+    CHECK_CALL(
+        XGDMatrixSetFloatInfo(R_ExternalPtrAddr(handle), CHAR(asChar(field)), BeginPtr(vec), len));
   }
   R_API_END();
   return R_NilValue;
@@ -313,15 +313,11 @@ XGB_DLL SEXP XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP h
       << "gradient and hess must have same length";
   int len = length(grad);
   std::vector<float> tgrad(len), thess(len);
-  dmlc::OMPException exc;
-#pragma omp parallel for schedule(static)
-  for (int j = 0; j < len; ++j) {
-    exc.Run([&]() {
-      tgrad[j] = REAL(grad)[j];
-      thess[j] = REAL(hess)[j];
-    });
-  }
-  exc.Rethrow();
+  auto ctx = BoosterCtx(R_ExternalPtrAddr(handle));
+  xgboost::common::ParallelFor(len, ctx->Threads(), [&](xgboost::omp_ulong j) {
+    tgrad[j] = REAL(grad)[j];
+    thess[j] = REAL(hess)[j];
+  });
   CHECK_CALL(XGBoosterBoostOneIter(R_ExternalPtrAddr(handle),
                                    R_ExternalPtrAddr(dtrain),
                                    BeginPtr(tgrad), BeginPtr(thess),
@@ -398,11 +394,10 @@ XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_con
     len *= out_shape[i];
   }
   r_out_result = PROTECT(allocVector(REALSXP, len));
-
-#pragma omp parallel for
-  for (omp_ulong i = 0; i < len; ++i) {
+  auto ctx = BoosterCtx(R_ExternalPtrAddr(handle));
+  xgboost::common::ParallelFor(len, ctx->Threads(), [&](xgboost::omp_ulong i) {
     REAL(r_out_result)[i] = out_result[i];
-  }
+  });
 
   r_out = PROTECT(allocVector(VECSXP, 2));
@@ -600,7 +595,6 @@ XGB_DLL SEXP XGBoosterFeatureScore_R(SEXP handle, SEXP json_config) {
   CHECK_CALL(XGBoosterFeatureScore(R_ExternalPtrAddr(handle), c_json_config,
                                    &out_n_features, &out_features,
                                    &out_dim, &out_shape, &out_scores));
-
   out_shape_sexp = PROTECT(allocVector(INTSXP, out_dim));
   size_t len = 1;
   for (size_t i = 0; i < out_dim; ++i) {
@@ -609,10 +603,10 @@ XGB_DLL SEXP XGBoosterFeatureScore_R(SEXP handle, SEXP json_config) {
   }
 
   out_scores_sexp = PROTECT(allocVector(REALSXP, len));
-#pragma omp parallel for
-  for (omp_ulong i = 0; i < len; ++i) {
+  auto ctx = BoosterCtx(R_ExternalPtrAddr(handle));
+  xgboost::common::ParallelFor(len, ctx->Threads(), [&](xgboost::omp_ulong i) {
     REAL(out_scores_sexp)[i] = out_scores[i];
-  }
+  });
 
   out_features_sexp = PROTECT(allocVector(STRSXP, out_n_features));
   for (size_t i = 0; i < out_n_features; ++i) {
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2014 (c) by Contributors
+ * Copyright 2014-2022 by XGBoost Contributors
  * \file xgboost_R.h
  * \author Tianqi Chen
  * \brief R wrapper of xgboost
@@ -59,12 +59,11 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat,
  * \param indices row indices
  * \param data content of the data
  * \param num_row numer of rows (when it's set to 0, then guess from data)
+ * \param n_threads Number of threads used to construct DMatrix from csc matrix.
  * \return created dmatrix
  */
-XGB_DLL SEXP XGDMatrixCreateFromCSC_R(SEXP indptr,
-                                      SEXP indices,
-                                      SEXP data,
-                                      SEXP num_row);
+XGB_DLL SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, SEXP indices, SEXP data, SEXP num_row,
+                                      SEXP n_threads);
 
 /*!
  * \brief create a new dmatrix from sliced content of existing matrix
@@ -1,5 +1,5 @@
 /*!
- * Copyright (c) 2015-2022 by Contributors
+ * Copyright (c) 2015-2022 by XGBoost Contributors
  * \file data.h
  * \brief The input data structure of xgboost.
  * \author Tianqi Chen
@@ -11,6 +11,7 @@
 #include <dmlc/data.h>
 #include <dmlc/serializer.h>
 #include <xgboost/base.h>
+#include <xgboost/generic_parameters.h>
 #include <xgboost/host_device_vector.h>
 #include <xgboost/linalg.h>
 #include <xgboost/span.h>
@@ -467,6 +468,11 @@ class DMatrix {
 
   /*! \brief Get thread local memory for returning data from DMatrix. */
   XGBAPIThreadLocalEntry& GetThreadLocal() const;
+  /**
+   * \brief Get the context object of this DMatrix.  The context is created during construction of
+   *        DMatrix with user specified `nthread` parameter.
+   */
+  virtual GenericParameter const* Ctx() const = 0;
 
   /**
   * \brief Gets batches. Use range based for loop over BatchSet to access individual batches.
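The effect of the new `Ctx()` accessor is that the thread count travels with the object: the `nthread` passed to `xgb.DMatrix` is stored in the DMatrix's context at construction, and later operations read it back through `Ctx()->Threads()`. A hedged sketch of that flow; `Matrix` and this `Context` are illustrative stand-ins, not the real class hierarchy:

// Sketch of the "context travels with the object" design introduced above.
// Only the Ctx()/Threads() names mirror the diff; the rest is hypothetical.
#include <algorithm>
#include <cstdint>
#include <omp.h>

struct Context {
  std::int32_t nthread{0};
  std::int32_t Threads() const {
    return nthread > 0 ? nthread
                       : std::min(omp_get_num_procs(), omp_get_max_threads());
  }
};

class Matrix {
 public:
  explicit Matrix(std::int32_t nthread) { ctx_.nthread = nthread; }
  Context const* Ctx() const { return &ctx_; }

 private:
  Context ctx_;
};

int main() {
  Matrix m{4};                          // user asks for 4 threads at construction
  std::int32_t n = m.Ctx()->Threads();  // later operations recover that choice
  return n > 0 ? 0 : 1;
}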
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2014-2021 by Contributors
+ * Copyright 2014-2022 by XGBoost Contributors
  * \file gbm.h
  * \brief Interface of gradient booster,
  *  that learns through gradient statistics.
@@ -39,6 +39,7 @@ class PredictionContainer;
 class GradientBooster : public Model, public Configurable {
  protected:
   GenericParameter const* ctx_;
+  explicit GradientBooster(GenericParameter const* ctx) : ctx_{ctx} {}
 
  public:
   /*! \brief virtual destructor */
@@ -208,9 +209,9 @@ class GradientBooster : public Model, public Configurable {
  */
 struct GradientBoosterReg
     : public dmlc::FunctionRegEntryBase<
-        GradientBoosterReg,
-        std::function<GradientBooster* (LearnerModelParam const* learner_model_param)> > {
-};
+          GradientBoosterReg,
+          std::function<GradientBooster*(LearnerModelParam const* learner_model_param,
+                                         GenericParameter const* ctx)> > {};
 
 /*!
  * \brief Macro to register gradient booster.
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2015-2021 by Contributors
+ * Copyright 2015-2022 by XGBoost Contributors
  * \file learner.h
  * \brief Learner interface that integrates objective, gbm and evaluation together.
  *  This is the user facing XGBoost training module.
@@ -280,8 +280,10 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
   * \return Created learner.
   */
  static Learner* Create(const std::vector<std::shared_ptr<DMatrix> >& cache_data);
-
-  virtual GenericParameter const& GetGenericParameter() const = 0;
+  /**
+   * \brief Return the context object of this Booster.
+   */
+  virtual GenericParameter const* Ctx() const = 0;
  /*!
   * \brief Get configuration arguments currently stored by the learner
   * \return Key-value pairs representing configuration arguments
@@ -177,6 +177,7 @@ void ParallelFor(Index size, int32_t n_threads, Sched sched, Func fn) {
   using OmpInd = Index;
 #endif
   OmpInd length = static_cast<OmpInd>(size);
+  CHECK_GE(n_threads, 1);
 
   dmlc::OMPException exc;
   switch (sched.sched) {
@@ -227,42 +228,16 @@ void ParallelFor(Index size, int32_t n_threads, Sched sched, Func fn) {
 }
 
 template <typename Index, typename Func>
-void ParallelFor(Index size, size_t n_threads, Func fn) {
+void ParallelFor(Index size, int32_t n_threads, Func fn) {
   ParallelFor(size, n_threads, Sched::Static(), fn);
 }
 
-// FIXME(jiamingy): Remove this function to get rid of `omp_set_num_threads`, which sets a
-// global variable in runtime and affects other programs in the same process.
-template <typename Index, typename Func>
-void ParallelFor(Index size, Func fn) {
-  ParallelFor(size, omp_get_max_threads(), Sched::Static(), fn);
-}
-
 inline int32_t OmpGetThreadLimit() {
   int32_t limit = omp_get_thread_limit();
   CHECK_GE(limit, 1) << "Invalid thread limit for OpenMP.";
   return limit;
 }
 
-/* \brief Configure parallel threads.
- *
- * \param p_threads Number of threads, when it's less than or equal to 0, this function
- *        will change it to number of process on system.
- *
- * \return Global openmp max threads before configuration.
- */
-inline int32_t OmpSetNumThreads(int32_t* p_threads) {
-  auto& threads = *p_threads;
-  int32_t nthread_original = omp_get_max_threads();
-  if (threads <= 0) {
-    threads = omp_get_num_procs();
-  }
-  threads = std::min(threads, OmpGetThreadLimit());
-  omp_set_num_threads(threads);
-  return nthread_original;
-}
-
 inline int32_t OmpGetNumThreads(int32_t n_threads) {
   if (n_threads <= 0) {
     n_threads = std::min(omp_get_num_procs(), omp_get_max_threads());
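After this hunk, every `ParallelFor` caller must pass a thread count explicitly: the overload that silently fell back to `omp_get_max_threads()` is gone, as is `OmpSetNumThreads`. A sketch of the post-change calling convention; the body below is a simplified stand-in for the real implementation in src/common/threading_utils.h (which also handles scheduling policy and routes exceptions through `dmlc::OMPException`):

// Simplified sketch of the explicit-thread-count ParallelFor convention.
#include <cstdint>
#include <vector>

template <typename Index, typename Func>
void ParallelFor(Index size, std::int32_t n_threads, Func fn) {
  // Dispatch with an explicit num_threads clause instead of relying on
  // whatever omp_set_num_threads last installed globally.
  #pragma omp parallel for schedule(static) num_threads(n_threads)
  for (std::int64_t i = 0; i < static_cast<std::int64_t>(size); ++i) {
    fn(static_cast<Index>(i));
  }
}

int main() {
  std::vector<float> data(100);
  std::int32_t n_threads = 4;  // in XGBoost this would come from ctx->Threads()
  ParallelFor(data.size(), n_threads, [&](std::size_t i) { data[i] = 1.0f; });
  return 0;
}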
@@ -69,11 +69,12 @@ class IterativeDeviceDMatrix : public DMatrix {
 
   bool SingleColBlock() const override { return false; }
 
-  MetaInfo& Info() override {
-    return info_;
-  }
-  MetaInfo const& Info() const override {
-    return info_;
-  }
+  MetaInfo &Info() override { return info_; }
+  MetaInfo const &Info() const override { return info_; }
+
+  GenericParameter const *Ctx() const override {
+    LOG(FATAL) << "`IterativeDMatrix` doesn't have context.";
+    return nullptr;
+  }
 };
@@ -79,6 +79,11 @@ class DMatrixProxy : public DMatrix {
 
   MetaInfo& Info() override { return info_; }
   MetaInfo const& Info() const override { return info_; }
+  GenericParameter const* Ctx() const override {
+    LOG(FATAL) << "`ProxyDMatrix` doesn't have context.";
+    return nullptr;
+  }
+
   bool SingleColBlock() const override { return true; }
   bool EllpackExists() const override { return true; }
   bool SparsePageExists() const override { return false; }
@@ -30,8 +30,8 @@ class SimpleDMatrix : public DMatrix {
   void SaveToLocalFile(const std::string& fname);
 
   MetaInfo& Info() override;
-
   const MetaInfo& Info() const override;
+  GenericParameter const* Ctx() const override { return &ctx_; }
 
   bool SingleColBlock() const override { return true; }
   DMatrix* Slice(common::Span<int32_t const> ridxs) override;
@@ -99,8 +99,8 @@ class SparsePageDMatrix : public DMatrix {
   }
 
   MetaInfo& Info() override;
-
   const MetaInfo& Info() const override;
+  GenericParameter const* Ctx() const override { return &ctx_; }
 
   bool SingleColBlock() const override { return false; }
   DMatrix *Slice(common::Span<int32_t const>) override {
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2014-2021 by Contributors
+ * Copyright 2014-2022 by XGBoost Contributors
  * \file gblinear.cc
  * \brief Implementation of Linear booster, with L1/L2 regularization: Elastic Net
  *        the update rule is parallel coordinate descent (shotgun)
@@ -71,8 +71,9 @@ void LinearCheckLayer(unsigned layer_begin) {
  */
 class GBLinear : public GradientBooster {
  public:
-  explicit GBLinear(LearnerModelParam const* learner_model_param)
-      : learner_model_param_{learner_model_param},
+  explicit GBLinear(LearnerModelParam const* learner_model_param, GenericParameter const* ctx)
+      : GradientBooster{ctx},
+        learner_model_param_{learner_model_param},
         model_{learner_model_param},
         previous_model_{learner_model_param},
         sum_instance_weight_(0),
@@ -190,7 +191,7 @@ class GBLinear : public GradientBooster {
     // parallel over local batch
     const auto nsize = static_cast<bst_omp_uint>(batch.Size());
     auto page = batch.GetView();
-    common::ParallelFor(nsize, [&](bst_omp_uint i) {
+    common::ParallelFor(nsize, ctx_->Threads(), [&](bst_omp_uint i) {
       auto inst = page[i];
       auto row_idx = static_cast<size_t>(batch.base_rowid + i);
       // loop over output groups
@@ -282,7 +283,7 @@ class GBLinear : public GradientBooster {
     if (base_margin.Size() != 0) {
       CHECK_EQ(base_margin.Size(), nsize * ngroup);
     }
-    common::ParallelFor(nsize, [&](omp_ulong i) {
+    common::ParallelFor(nsize, ctx_->Threads(), [&](omp_ulong i) {
       const size_t ridx = page.base_rowid + i;
       // loop over output groups
       for (int gid = 0; gid < ngroup; ++gid) {
@@ -351,8 +352,8 @@ DMLC_REGISTER_PARAMETER(GBLinearTrainParam);
 
 XGBOOST_REGISTER_GBM(GBLinear, "gblinear")
     .describe("Linear booster, implement generalized linear model.")
-    .set_body([](LearnerModelParam const* booster_config) {
-      return new GBLinear(booster_config);
+    .set_body([](LearnerModelParam const* booster_config, GenericParameter const* ctx) {
+      return new GBLinear(booster_config, ctx);
     });
 }  // namespace gbm
 }  // namespace xgboost
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2015-2020 by Contributors
+ * Copyright 2015-2022 by XGBoost Contributors
  * \file gbm.cc
  * \brief Registry of gradient boosters.
  */
@@ -17,16 +17,13 @@ DMLC_REGISTRY_ENABLE(::xgboost::GradientBoosterReg);
 }  // namespace dmlc
 
 namespace xgboost {
-GradientBooster* GradientBooster::Create(
-    const std::string& name,
-    GenericParameter const* generic_param,
-    LearnerModelParam const* learner_model_param) {
+GradientBooster* GradientBooster::Create(const std::string& name, GenericParameter const* ctx,
+                                         LearnerModelParam const* learner_model_param) {
   auto *e = ::dmlc::Registry< ::xgboost::GradientBoosterReg>::Get()->Find(name);
   if (e == nullptr) {
     LOG(FATAL) << "Unknown gbm type " << name;
   }
-  auto p_bst = (e->body)(learner_model_param);
-  p_bst->ctx_ = generic_param;
+  auto p_bst = (e->body)(learner_model_param, ctx);
   return p_bst;
 }
 }  // namespace xgboost
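This hunk implements the "context available at construction time" bullet from the commit message: rather than creating the booster and patching its `ctx_` member afterwards, the factory now hands the context to the constructor, so a booster never exists in a context-free state. A sketch of why single-phase initialization is the safer shape; these are hypothetical minimal types, not the real registry machinery:

// Sketch: two-phase vs. single-phase context initialization.
struct Context {};

struct Booster {
  explicit Booster(Context const* ctx) : ctx_{ctx} {}
  Context const* ctx_;  // invariant: non-null from the moment of construction
};

// Before: create first, assign ctx_ later -- the object is briefly invalid.
// After: the factory signature forces callers to provide the context up front.
Booster* Create(Context const* ctx) { return new Booster{ctx}; }

int main() {
  Context ctx;
  Booster* b = Create(&ctx);
  bool ok = (b->ctx_ == &ctx);
  delete b;
  return ok ? 0 : 1;
}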
@@ -636,8 +636,8 @@ void GPUDartInplacePredictInc(common::Span<float> out_predts,
 
 class Dart : public GBTree {
  public:
-  explicit Dart(LearnerModelParam const* booster_config) :
-      GBTree(booster_config) {}
+  explicit Dart(LearnerModelParam const* booster_config, GenericParameter const* ctx)
+      : GBTree(booster_config, ctx) {}
 
   void Configure(const Args& cfg) override {
     GBTree::Configure(cfg);
@@ -1018,16 +1018,16 @@ DMLC_REGISTER_PARAMETER(GBTreeTrainParam);
 DMLC_REGISTER_PARAMETER(DartTrainParam);
 
 XGBOOST_REGISTER_GBM(GBTree, "gbtree")
-.describe("Tree booster, gradient boosted trees.")
-.set_body([](LearnerModelParam const* booster_config) {
-    auto* p = new GBTree(booster_config);
-    return p;
-  });
+    .describe("Tree booster, gradient boosted trees.")
+    .set_body([](LearnerModelParam const* booster_config, GenericParameter const* ctx) {
+      auto* p = new GBTree(booster_config, ctx);
+      return p;
+    });
 XGBOOST_REGISTER_GBM(Dart, "dart")
-.describe("Tree booster, dart.")
-.set_body([](LearnerModelParam const* booster_config) {
-    GBTree* p = new Dart(booster_config);
-    return p;
-  });
+    .describe("Tree booster, dart.")
+    .set_body([](LearnerModelParam const* booster_config, GenericParameter const* ctx) {
+      GBTree* p = new Dart(booster_config, ctx);
+      return p;
+    });
 }  // namespace gbm
 }  // namespace xgboost
@@ -202,8 +202,8 @@ inline bool SliceTrees(int32_t layer_begin, int32_t layer_end, int32_t step,
 // gradient boosted trees
 class GBTree : public GradientBooster {
  public:
-  explicit GBTree(LearnerModelParam const* booster_config) :
-      model_(booster_config) {}
+  explicit GBTree(LearnerModelParam const* booster_config, GenericParameter const* ctx)
+      : GradientBooster{ctx}, model_(booster_config, ctx_) {}
 
   void Configure(const Args& cfg) override;
   // Revise `tree_method` and `updater` parameters after seeing the training
@@ -69,7 +69,8 @@ void GBTreeModel::SaveModel(Json* p_out) const {
   out["gbtree_model_param"] = ToJson(param);
   std::vector<Json> trees_json(trees.size());
 
-  common::ParallelFor(trees.size(), omp_get_max_threads(), [&](auto t) {
+  CHECK(ctx_);
+  common::ParallelFor(trees.size(), ctx_->Threads(), [&](auto t) {
     auto const& tree = trees[t];
     Json tree_json{Object()};
     tree->SaveModel(&tree_json);
@@ -95,7 +96,8 @@ void GBTreeModel::LoadModel(Json const& in) {
   auto const& trees_json = get<Array const>(in["trees"]);
   trees.resize(trees_json.size());
 
-  common::ParallelFor(trees_json.size(), omp_get_max_threads(), [&](auto t) {
+  CHECK(ctx_);
+  common::ParallelFor(trees_json.size(), ctx_->Threads(), [&](auto t) {
     auto tree_id = get<Integer>(trees_json[t]["id"]);
     trees.at(tree_id).reset(new RegTree());
     trees.at(tree_id)->LoadModel(trees_json[t]);
@@ -83,8 +83,8 @@ struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
 
 struct GBTreeModel : public Model {
  public:
-  explicit GBTreeModel(LearnerModelParam const* learner_model) :
-      learner_model_param{learner_model} {}
+  explicit GBTreeModel(LearnerModelParam const* learner_model, GenericParameter const* ctx)
+      : learner_model_param{learner_model}, ctx_{ctx} {}
   void Configure(const Args& cfg) {
     // initialize model parameters if not yet been initialized.
     if (trees.size() == 0) {
@@ -135,6 +135,9 @@ struct GBTreeModel : public Model {
   std::vector<std::unique_ptr<RegTree> > trees_to_update;
   /*! \brief some information indicator of the tree, reserved */
   std::vector<int> tree_info;
+
+ private:
+  GenericParameter const* ctx_;
 };
 }  // namespace gbm
 }  // namespace xgboost
@@ -331,7 +331,6 @@ class LearnerConfiguration : public Learner {
     generic_parameters_.UpdateAllowUnknown(args);
 
     ConsoleLogger::Configure(args);
-    common::OmpSetNumThreads(&generic_parameters_.nthread);
 
     // add additional parameters
     // These are cosntraints that need to be satisfied.
@@ -522,9 +521,7 @@ class LearnerConfiguration : public Learner {
     return cfg_;
   }
 
-  GenericParameter const& GetGenericParameter() const override {
-    return generic_parameters_;
-  }
+  GenericParameter const* Ctx() const override { return &generic_parameters_; }
 
  private:
  void ValidateParameters() {
@@ -111,9 +111,8 @@ struct EvalAMS : public Metric {
     PredIndPairContainer rec(ndata);
 
     const auto &h_preds = preds.ConstHostVector();
-    common::ParallelFor(ndata, [&](bst_omp_uint i) {
-      rec[i] = std::make_pair(h_preds[i], i);
-    });
+    common::ParallelFor(ndata, tparam_->Threads(),
+                        [&](bst_omp_uint i) { rec[i] = std::make_pair(h_preds[i], i); });
     XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
     auto ntop = static_cast<unsigned>(ratio_ * ndata);
     if (ntop == 0) ntop = ndata;
@@ -826,7 +826,7 @@ class LambdaRankObj : public ObjFunction {
     out_gpair->Resize(preds.Size());
 
     dmlc::OMPException exc;
-#pragma omp parallel
+#pragma omp parallel num_threads(ctx_->Threads())
     {
       exc.Run([&]() {
         // parallel construct, declare random number generator here, so that each
@@ -14,15 +14,7 @@
 namespace xgboost {
 namespace common {
 
-size_t GetNThreads() {
-  size_t nthreads;
-#pragma omp parallel
-  {
-#pragma omp master
-    nthreads = omp_get_num_threads();
-  }
-  return nthreads;
-}
+size_t GetNThreads() { return common::OmpGetNumThreads(0); }
 
 template <typename GradientSumT>
 void ParallelGHistBuilderReset() {
@@ -590,10 +590,8 @@ TEST(Json, DISABLED_RoundTripExhaustive) {
     }
   };
   int64_t int32_max = static_cast<int64_t>(std::numeric_limits<uint32_t>::max());
-#pragma omp parallel for schedule(static)
-  for (int64_t i = 0; i <= int32_max; ++i) {
-    test(static_cast<uint32_t>(i));
-  }
+  GenericParameter ctx;
+  common::ParallelFor(int32_max, ctx.Threads(), [&](auto i) { test(static_cast<uint32_t>(i)); });
 }
 
 TEST(Json, TypedArray) {
@@ -88,22 +88,5 @@ TEST(ParallelFor2dNonUniform, Test) {
 
   omp_set_num_threads(old);
 }
-#if defined(_OPENMP)
-TEST(OmpSetNumThreads, Basic) {
-  auto nthreads = 2;
-  auto orgi = OmpSetNumThreads(&nthreads);
-  ASSERT_EQ(omp_get_max_threads(), 2);
-  nthreads = 0;
-  OmpSetNumThreads(&nthreads);
-  ASSERT_EQ(omp_get_max_threads(), omp_get_num_procs());
-  nthreads = 1;
-  OmpSetNumThreads(&nthreads);
-  nthreads = 0;
-  OmpSetNumThreads(&nthreads);
-  ASSERT_EQ(omp_get_max_threads(), omp_get_num_procs());
-
-  omp_set_num_threads(orgi);
-}
-#endif  // defined(_OPENMP)
 }  // namespace common
 }  // namespace xgboost
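The deleted test above exercises exactly the behaviour this PR removes: `omp_set_num_threads` mutates process-global state, so a library that calls it changes threading for every other OpenMP user in the same process (the rationale spelled out in the FIXME removed from threading_utils.h). A minimal sketch of the leak; compile with -fopenmp:

// Sketch: omp_set_num_threads leaks across library boundaries in one process.
#include <cstdio>
#include <omp.h>

void library_a_configure() {
  omp_set_num_threads(2);  // "library A" configures itself...
}

int main() {
  std::printf("before: %d\n", omp_get_max_threads());
  library_a_configure();
  // ...and every later parallel region in the process now sees the reduced
  // limit, including code that never asked for it.
  std::printf("after:  %d\n", omp_get_max_threads());
  return 0;
}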
@@ -506,8 +506,9 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
   return dmat;
 }
 
-gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param, size_t n_classes) {
-  gbm::GBTreeModel model(param);
+gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param, GenericParameter const* ctx,
+                                 size_t n_classes) {
+  gbm::GBTreeModel model(param, ctx);
 
   for (size_t i = 0; i < n_classes; ++i) {
     std::vector<std::unique_ptr<RegTree>> trees;
@@ -357,7 +357,8 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
     size_t n_rows, size_t n_cols, size_t page_size, bool deterministic,
     const dmlc::TemporaryDirectory& tempdir = dmlc::TemporaryDirectory());
 
-gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param, size_t n_classes = 1);
+gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param, GenericParameter const* ctx,
+                                 size_t n_classes = 1);
 
 std::unique_ptr<GradientBooster> CreateTrainedGBM(
     std::string name, Args kwargs, size_t kRows, size_t kCols,
@@ -25,7 +25,9 @@ TEST(CpuPredictor, Basic) {
   param.base_score = 0.0;
   param.num_output_group = 1;
 
-  gbm::GBTreeModel model = CreateTestModel(&param);
+  GenericParameter ctx;
+  ctx.UpdateAllowUnknown(Args{});
+  gbm::GBTreeModel model = CreateTestModel(&param, &ctx);
 
   auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
@@ -106,7 +108,9 @@ TEST(CpuPredictor, ExternalMemory) {
   param.num_feature = dmat->Info().num_col_;
   param.num_output_group = 1;
 
-  gbm::GBTreeModel model = CreateTestModel(&param);
+  GenericParameter ctx;
+  ctx.UpdateAllowUnknown(Args{});
+  gbm::GBTreeModel model = CreateTestModel(&param, &ctx);
 
   // Test predict batch
   PredictionCacheEntry out_predictions;
@@ -38,7 +38,9 @@ TEST(GPUPredictor, Basic) {
   param.num_output_group = 1;
   param.base_score = 0.5;
 
-  gbm::GBTreeModel model = CreateTestModel(&param);
+  GenericParameter ctx;
+  ctx.UpdateAllowUnknown(Args{});
+  gbm::GBTreeModel model = CreateTestModel(&param, &ctx);
 
   // Test predict batch
   PredictionCacheEntry gpu_out_predictions;
@@ -100,7 +102,9 @@ TEST(GPUPredictor, ExternalMemoryTest) {
   param.num_output_group = n_classes;
   param.base_score = 0.5;
 
-  gbm::GBTreeModel model = CreateTestModel(&param, n_classes);
+  GenericParameter ctx;
+  ctx.UpdateAllowUnknown(Args{});
+  gbm::GBTreeModel model = CreateTestModel(&param, &ctx, n_classes);
   std::vector<std::unique_ptr<DMatrix>> dmats;
 
   dmats.push_back(CreateSparsePageDMatrix(400));
@@ -167,11 +171,17 @@ TEST(GpuPredictor, LesserFeatures) {
 // Very basic test of empty model
 TEST(GPUPredictor, ShapStump) {
   cudaSetDevice(0);
+
   LearnerModelParam param;
   param.num_feature = 1;
   param.num_output_group = 1;
   param.base_score = 0.5;
-  gbm::GBTreeModel model(&param);
+
+  GenericParameter ctx;
+  ctx.UpdateAllowUnknown(Args{});
+
+  gbm::GBTreeModel model(&param, &ctx);
+
   std::vector<std::unique_ptr<RegTree>> trees;
   trees.push_back(std::unique_ptr<RegTree>(new RegTree));
   model.CommitModel(std::move(trees), 0);
@@ -197,7 +207,12 @@ TEST(GPUPredictor, Shap) {
   param.num_feature = 1;
   param.num_output_group = 1;
   param.base_score = 0.5;
-  gbm::GBTreeModel model(&param);
+
+  GenericParameter ctx;
+  ctx.UpdateAllowUnknown(Args{});
+
+  gbm::GBTreeModel model(&param, &ctx);
+
   std::vector<std::unique_ptr<RegTree>> trees;
   trees.push_back(std::unique_ptr<RegTree>(new RegTree));
   trees[0]->ExpandNode(0, 0, 0.5, true, 1.0, -1.0, 1.0, 0.0, 5.0, 2.0, 3.0);
@@ -249,7 +264,9 @@ TEST(GPUPredictor, PredictLeafBasic) {
   param.base_score = 0.0;
   param.num_output_group = 1;
 
-  gbm::GBTreeModel model = CreateTestModel(&param);
+  GenericParameter ctx;
+  ctx.UpdateAllowUnknown(Args{});
+  gbm::GBTreeModel model = CreateTestModel(&param, &ctx);
 
   HostDeviceVector<float> leaf_out_predictions;
   gpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model);
@@ -214,10 +214,11 @@ void TestCategoricalPrediction(std::string name) {
   float left_weight = 1.3f;
   float right_weight = 1.7f;
 
-  gbm::GBTreeModel model(&param);
+  GenericParameter ctx;
+  ctx.UpdateAllowUnknown(Args{});
+  gbm::GBTreeModel model(&param, &ctx);
   GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);
 
-  GenericParameter ctx;
   ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
   std::unique_ptr<Predictor> predictor{Predictor::Create(name.c_str(), &ctx)};
@@ -257,13 +258,14 @@ void TestCategoricalPredictLeaf(StringView name) {
   float left_weight = 1.3f;
   float right_weight = 1.7f;
 
-  gbm::GBTreeModel model(&param);
+  GenericParameter ctx;
+  ctx.UpdateAllowUnknown(Args{});
+
+  gbm::GBTreeModel model(&param, &ctx);
   GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);
 
-  GenericParameter runtime;
-  runtime.gpu_id = 0;
-  std::unique_ptr<Predictor> predictor{
-      Predictor::Create(name.c_str(), &runtime)};
+  ctx.gpu_id = 0;
+  std::unique_ptr<Predictor> predictor{Predictor::Create(name.c_str(), &ctx)};
 
   std::vector<float> row(kCols);
   row[split_ind] = split_cat;
@@ -23,7 +23,9 @@ void TestPredictionFromGradientIndex(std::string name, size_t rows, size_t cols,
       std::unique_ptr<Predictor>(Predictor::Create(name, &lparam));
   predictor->Configure({});
 
-  gbm::GBTreeModel model = CreateTestModel(&param, kClasses);
+  GenericParameter ctx;
+  ctx.UpdateAllowUnknown(Args{});
+  gbm::GBTreeModel model = CreateTestModel(&param, &ctx, kClasses);
 
   {
     auto p_precise = RandomDataGenerator(rows, cols, 0).GenerateDMatrix();
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2017-2020 XGBoost contributors
+ * Copyright 2017-2022 by XGBoost contributors
 */
 #include <gtest/gtest.h>
 #include <vector>
@@ -284,27 +284,27 @@ TEST(Learner, GPUConfiguration) {
     learner->SetParams({Arg{"booster", "gblinear"},
                         Arg{"updater", "gpu_coord_descent"}});
     learner->UpdateOneIter(0, p_dmat);
-    ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
+    ASSERT_EQ(learner->Ctx()->gpu_id, 0);
   }
   {
     std::unique_ptr<Learner> learner {Learner::Create(mat)};
     learner->SetParams({Arg{"tree_method", "gpu_hist"}});
     learner->UpdateOneIter(0, p_dmat);
-    ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
+    ASSERT_EQ(learner->Ctx()->gpu_id, 0);
   }
   {
     std::unique_ptr<Learner> learner {Learner::Create(mat)};
     learner->SetParams({Arg{"tree_method", "gpu_hist"},
                         Arg{"gpu_id", "-1"}});
     learner->UpdateOneIter(0, p_dmat);
-    ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
+    ASSERT_EQ(learner->Ctx()->gpu_id, 0);
   }
   {
     // with CPU algorithm
     std::unique_ptr<Learner> learner {Learner::Create(mat)};
     learner->SetParams({Arg{"tree_method", "hist"}});
     learner->UpdateOneIter(0, p_dmat);
-    ASSERT_EQ(learner->GetGenericParameter().gpu_id, -1);
+    ASSERT_EQ(learner->Ctx()->gpu_id, -1);
   }
   {
     // with CPU algorithm, but `gpu_id` takes priority
@@ -312,7 +312,7 @@ TEST(Learner, GPUConfiguration) {
     learner->SetParams({Arg{"tree_method", "hist"},
                         Arg{"gpu_id", "0"}});
     learner->UpdateOneIter(0, p_dmat);
-    ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
+    ASSERT_EQ(learner->Ctx()->gpu_id, 0);
   }
   {
     // With CPU algorithm but GPU Predictor, this is to simulate when
@@ -322,7 +322,7 @@ TEST(Learner, GPUConfiguration) {
     learner->SetParams({Arg{"tree_method", "hist"},
                         Arg{"predictor", "gpu_predictor"}});
     learner->UpdateOneIter(0, p_dmat);
-    ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
+    ASSERT_EQ(learner->Ctx()->gpu_id, 0);
   }
 }
 #endif  // defined(XGBOOST_USE_CUDA)