Define multi-strategy parameter. (#8890)
This commit is contained in:
@@ -3,18 +3,19 @@
|
||||
*/
|
||||
#include "xgboost/predictor.h"
|
||||
|
||||
#include <dmlc/registry.h>
|
||||
#include <dmlc/registry.h> // for DMLC_REGISTRY_LINK_TAG
|
||||
|
||||
#include <string> // std::string
|
||||
#include <cstdint> // for int32_t
|
||||
#include <string> // for string, to_string
|
||||
|
||||
#include "../gbm/gbtree.h" // GBTreeModel
|
||||
#include "xgboost/base.h" // bst_row_t,bst_group_t
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/data.h" // MetaInfo
|
||||
#include "xgboost/host_device_vector.h" // HostDeviceVector
|
||||
#include "xgboost/learner.h" // LearnerModelParam
|
||||
#include "xgboost/linalg.h" // Tensor
|
||||
#include "xgboost/logging.h"
|
||||
#include "../gbm/gbtree_model.h" // for GBTreeModel
|
||||
#include "xgboost/base.h" // for bst_float, Args, bst_group_t, bst_row_t
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/data.h" // for MetaInfo
|
||||
#include "xgboost/host_device_vector.h" // for HostDeviceVector
|
||||
#include "xgboost/learner.h" // for LearnerModelParam
|
||||
#include "xgboost/linalg.h" // for Tensor, TensorView
|
||||
#include "xgboost/logging.h" // for CHECK_EQ, CHECK_NE, LOG
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg);
|
||||
@@ -45,15 +46,16 @@ void ValidateBaseMarginShape(linalg::Tensor<float, D> const& margin, bst_row_t n
|
||||
void Predictor::InitOutPredictions(const MetaInfo& info, HostDeviceVector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model) const {
|
||||
CHECK_NE(model.learner_model_param->num_output_group, 0);
|
||||
size_t n_classes = model.learner_model_param->num_output_group;
|
||||
size_t n = n_classes * info.num_row_;
|
||||
std::size_t n{model.learner_model_param->OutputLength() * info.num_row_};
|
||||
|
||||
const HostDeviceVector<bst_float>* base_margin = info.base_margin_.Data();
|
||||
if (ctx_->gpu_id >= 0) {
|
||||
out_preds->SetDevice(ctx_->gpu_id);
|
||||
}
|
||||
if (!base_margin->Empty()) {
|
||||
out_preds->Resize(n);
|
||||
ValidateBaseMarginShape(info.base_margin_, info.num_row_, n_classes);
|
||||
ValidateBaseMarginShape(info.base_margin_, info.num_row_,
|
||||
model.learner_model_param->OutputLength());
|
||||
out_preds->Copy(*base_margin);
|
||||
} else {
|
||||
// cannot rely on the Resize to fill as it might skip if the size is already correct.
|
||||
@@ -64,12 +66,10 @@ void Predictor::InitOutPredictions(const MetaInfo& info, HostDeviceVector<bst_fl
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
namespace xgboost {
|
||||
namespace predictor {
|
||||
namespace xgboost::predictor {
|
||||
// List of files that will be force linked in static links.
|
||||
#ifdef XGBOOST_USE_CUDA
|
||||
DMLC_REGISTRY_LINK_TAG(gpu_predictor);
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
DMLC_REGISTRY_LINK_TAG(cpu_predictor);
|
||||
} // namespace predictor
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::predictor
|
||||
|
||||
Reference in New Issue
Block a user