Replaced std::vector-based interfaces with HostDeviceVector-based interfaces. (#3116)
* Replaced std::vector-based interfaces with HostDeviceVector-based interfaces. - replacement was performed in the learner, boosters, predictors, updaters, and objective functions - only interfaces used in training were replaced; interfaces like PredictInstance() still use std::vector - refactoring necessary for replacement of interfaces was also performed, such as using HostDeviceVector in prediction cache * HostDeviceVector-based interfaces for custom objective function example plugin.
This commit is contained in:
committed by
Rory Mitchell
parent
11bfa8584d
commit
d5992dd881
@@ -191,9 +191,9 @@ struct XGBAPIThreadLocalEntry {
|
||||
/*! \brief result holder for returning string pointers */
|
||||
std::vector<const char *> ret_vec_charp;
|
||||
/*! \brief returning float vector. */
|
||||
std::vector<bst_float> ret_vec_float;
|
||||
HostDeviceVector<bst_float> ret_vec_float;
|
||||
/*! \brief temp variable of gradient pairs. */
|
||||
std::vector<bst_gpair> tmp_gpair;
|
||||
HostDeviceVector<bst_gpair> tmp_gpair;
|
||||
};
|
||||
|
||||
// define the threadlocal store.
|
||||
@@ -705,14 +705,15 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle,
|
||||
bst_float *grad,
|
||||
bst_float *hess,
|
||||
xgboost::bst_ulong len) {
|
||||
std::vector<bst_gpair>& tmp_gpair = XGBAPIThreadLocalStore::Get()->tmp_gpair;
|
||||
HostDeviceVector<bst_gpair>& tmp_gpair = XGBAPIThreadLocalStore::Get()->tmp_gpair;
|
||||
API_BEGIN();
|
||||
Booster* bst = static_cast<Booster*>(handle);
|
||||
std::shared_ptr<DMatrix>* dtr =
|
||||
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
|
||||
tmp_gpair.resize(len);
|
||||
std::vector<bst_gpair>& tmp_gpair_h = tmp_gpair.data_h();
|
||||
for (xgboost::bst_ulong i = 0; i < len; ++i) {
|
||||
tmp_gpair[i] = bst_gpair(grad[i], hess[i]);
|
||||
tmp_gpair_h[i] = bst_gpair(grad[i], hess[i]);
|
||||
}
|
||||
|
||||
bst->LazyInit();
|
||||
@@ -749,7 +750,8 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
||||
unsigned ntree_limit,
|
||||
xgboost::bst_ulong *len,
|
||||
const bst_float **out_result) {
|
||||
std::vector<bst_float>& preds = XGBAPIThreadLocalStore::Get()->ret_vec_float;
|
||||
HostDeviceVector<bst_float>& preds =
|
||||
XGBAPIThreadLocalStore::Get()->ret_vec_float;
|
||||
API_BEGIN();
|
||||
Booster *bst = static_cast<Booster*>(handle);
|
||||
bst->LazyInit();
|
||||
@@ -761,7 +763,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
||||
(option_mask & 4) != 0,
|
||||
(option_mask & 8) != 0,
|
||||
(option_mask & 16) != 0);
|
||||
*out_result = dmlc::BeginPtr(preds);
|
||||
*out_result = dmlc::BeginPtr(preds.data_h());
|
||||
*len = static_cast<xgboost::bst_ulong>(preds.size());
|
||||
API_END();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user