Fix inplace predict missing value. (#6787)

This commit is contained in:
Jiaming Yuan
2021-03-27 05:36:10 +08:00
committed by GitHub
parent 5c87c2bba8
commit a59c7323b4
8 changed files with 97 additions and 33 deletions

View File

@@ -255,7 +255,7 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr,
data::CSRArrayAdapter adapter(StringView{indptr}, StringView{indices},
StringView{data}, ncol);
auto config = Json::Load(StringView{c_json_config});
float missing = get<Number const>(config["missing"]);
float missing = GetMissing(config);
auto nthread = get<Integer const>(config["nthread"]);
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, nthread));
API_END();
@@ -683,8 +683,8 @@ void InplacePredictImpl(std::shared_ptr<T> x, std::shared_ptr<DMatrix> p_m,
HostDeviceVector<float>* p_predt { nullptr };
auto type = PredictionType(get<Integer const>(config["type"]));
learner->InplacePredict(x, p_m, type, get<Number const>(config["missing"]),
&p_predt,
float missing = GetMissing(config);
learner->InplacePredict(x, p_m, type, missing, &p_predt,
get<Integer const>(config["iteration_begin"]),
get<Integer const>(config["iteration_end"]));
CHECK(p_predt);

View File

@@ -48,8 +48,9 @@ int InplacePreidctCuda(BoosterHandle handle, char const *c_json_strs,
auto x = std::make_shared<T>(json_str);
HostDeviceVector<float> *p_predt{nullptr};
auto type = PredictionType(get<Integer const>(config["type"]));
learner->InplacePredict(x, p_m, type, get<Number const>(config["missing"]),
&p_predt,
float missing = GetMissing(config);
learner->InplacePredict(x, p_m, type, missing, &p_predt,
get<Integer const>(config["iteration_begin"]),
get<Integer const>(config["iteration_end"]));
CHECK(p_predt);

View File

@@ -11,6 +11,9 @@
#include "xgboost/logging.h"
#include "xgboost/json.h"
#include "xgboost/learner.h"
#include "xgboost/c_api.h"
#include "c_api_error.h"
namespace xgboost {
/* \brief Determine the output shape of prediction.
@@ -141,5 +144,19 @@ inline uint32_t GetIterationFromTreeLimit(uint32_t ntree_limit, Learner *learner
}
return ntree_limit;
}
inline float GetMissing(Json const &config) {
float missing;
auto const& j_missing = config["missing"];
if (IsA<Number const>(j_missing)) {
missing = get<Number const>(j_missing);
} else if (IsA<Integer const>(j_missing)) {
missing = get<Integer const>(j_missing);
} else {
missing = nan("");
LOG(FATAL) << "Invalid missing value: " << j_missing;
}
return missing;
}
} // namespace xgboost
#endif // XGBOOST_C_API_C_API_UTILS_H_