Define multi-strategy parameter. (#8890)

This commit is contained in:
Jiaming Yuan
2023-03-11 02:58:01 +08:00
committed by GitHub
parent 6deaec8027
commit 2aa838c75e
6 changed files with 221 additions and 134 deletions

View File

@@ -3,18 +3,19 @@
*/
#include "xgboost/predictor.h"
#include <dmlc/registry.h>
#include <dmlc/registry.h> // for DMLC_REGISTRY_LINK_TAG
#include <string> // std::string
#include <cstdint> // for int32_t
#include <string> // for string, to_string
#include "../gbm/gbtree.h" // GBTreeModel
#include "xgboost/base.h" // bst_row_t,bst_group_t
#include "xgboost/context.h" // Context
#include "xgboost/data.h" // MetaInfo
#include "xgboost/host_device_vector.h" // HostDeviceVector
#include "xgboost/learner.h" // LearnerModelParam
#include "xgboost/linalg.h" // Tensor
#include "xgboost/logging.h"
#include "../gbm/gbtree_model.h" // for GBTreeModel
#include "xgboost/base.h" // for bst_float, Args, bst_group_t, bst_row_t
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for MetaInfo
#include "xgboost/host_device_vector.h" // for HostDeviceVector
#include "xgboost/learner.h" // for LearnerModelParam
#include "xgboost/linalg.h" // for Tensor, TensorView
#include "xgboost/logging.h" // for CHECK_EQ, CHECK_NE, LOG
namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg);
@@ -45,15 +46,16 @@ void ValidateBaseMarginShape(linalg::Tensor<float, D> const& margin, bst_row_t n
void Predictor::InitOutPredictions(const MetaInfo& info, HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model) const {
CHECK_NE(model.learner_model_param->num_output_group, 0);
size_t n_classes = model.learner_model_param->num_output_group;
size_t n = n_classes * info.num_row_;
std::size_t n{model.learner_model_param->OutputLength() * info.num_row_};
const HostDeviceVector<bst_float>* base_margin = info.base_margin_.Data();
if (ctx_->gpu_id >= 0) {
out_preds->SetDevice(ctx_->gpu_id);
}
if (!base_margin->Empty()) {
out_preds->Resize(n);
ValidateBaseMarginShape(info.base_margin_, info.num_row_, n_classes);
ValidateBaseMarginShape(info.base_margin_, info.num_row_,
model.learner_model_param->OutputLength());
out_preds->Copy(*base_margin);
} else {
// cannot rely on the Resize to fill as it might skip if the size is already correct.
@@ -64,12 +66,10 @@ void Predictor::InitOutPredictions(const MetaInfo& info, HostDeviceVector<bst_fl
}
} // namespace xgboost
namespace xgboost {
namespace predictor {
namespace xgboost::predictor {
// List of files that will be force linked in static links.
#ifdef XGBOOST_USE_CUDA
DMLC_REGISTRY_LINK_TAG(gpu_predictor);
#endif // XGBOOST_USE_CUDA
DMLC_REGISTRY_LINK_TAG(cpu_predictor);
} // namespace predictor
} // namespace xgboost
} // namespace xgboost::predictor