Fix inplace predict missing value. (#6787)
This commit is contained in:
@@ -255,7 +255,7 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr,
|
||||
data::CSRArrayAdapter adapter(StringView{indptr}, StringView{indices},
|
||||
StringView{data}, ncol);
|
||||
auto config = Json::Load(StringView{c_json_config});
|
||||
float missing = get<Number const>(config["missing"]);
|
||||
float missing = GetMissing(config);
|
||||
auto nthread = get<Integer const>(config["nthread"]);
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, nthread));
|
||||
API_END();
|
||||
@@ -683,8 +683,8 @@ void InplacePredictImpl(std::shared_ptr<T> x, std::shared_ptr<DMatrix> p_m,
|
||||
|
||||
HostDeviceVector<float>* p_predt { nullptr };
|
||||
auto type = PredictionType(get<Integer const>(config["type"]));
|
||||
learner->InplacePredict(x, p_m, type, get<Number const>(config["missing"]),
|
||||
&p_predt,
|
||||
float missing = GetMissing(config);
|
||||
learner->InplacePredict(x, p_m, type, missing, &p_predt,
|
||||
get<Integer const>(config["iteration_begin"]),
|
||||
get<Integer const>(config["iteration_end"]));
|
||||
CHECK(p_predt);
|
||||
|
||||
@@ -48,8 +48,9 @@ int InplacePreidctCuda(BoosterHandle handle, char const *c_json_strs,
|
||||
auto x = std::make_shared<T>(json_str);
|
||||
HostDeviceVector<float> *p_predt{nullptr};
|
||||
auto type = PredictionType(get<Integer const>(config["type"]));
|
||||
learner->InplacePredict(x, p_m, type, get<Number const>(config["missing"]),
|
||||
&p_predt,
|
||||
float missing = GetMissing(config);
|
||||
|
||||
learner->InplacePredict(x, p_m, type, missing, &p_predt,
|
||||
get<Integer const>(config["iteration_begin"]),
|
||||
get<Integer const>(config["iteration_end"]));
|
||||
CHECK(p_predt);
|
||||
|
||||
@@ -11,6 +11,9 @@
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/learner.h"
|
||||
#include "xgboost/c_api.h"
|
||||
|
||||
#include "c_api_error.h"
|
||||
|
||||
namespace xgboost {
|
||||
/* \brief Determine the output shape of prediction.
|
||||
@@ -141,5 +144,19 @@ inline uint32_t GetIterationFromTreeLimit(uint32_t ntree_limit, Learner *learner
|
||||
}
|
||||
return ntree_limit;
|
||||
}
|
||||
|
||||
inline float GetMissing(Json const &config) {
|
||||
float missing;
|
||||
auto const& j_missing = config["missing"];
|
||||
if (IsA<Number const>(j_missing)) {
|
||||
missing = get<Number const>(j_missing);
|
||||
} else if (IsA<Integer const>(j_missing)) {
|
||||
missing = get<Integer const>(j_missing);
|
||||
} else {
|
||||
missing = nan("");
|
||||
LOG(FATAL) << "Invalid missing value: " << j_missing;
|
||||
}
|
||||
return missing;
|
||||
}
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_C_API_C_API_UTILS_H_
|
||||
|
||||
Reference in New Issue
Block a user