Fix prediction heuristic (#5955)

* Relax check for prediction.
* Relax test in spark test.
* Add tests in C++.
This commit is contained in:
Jiaming Yuan 2020-07-29 19:24:07 +08:00 committed by GitHub
parent 5879acde9a
commit 75b8c22b0b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 103 additions and 28 deletions

View File

@ -26,5 +26,3 @@ data = "machine.txt.train"
eval[test] = "machine.txt.test" eval[test] = "machine.txt.test"
# The path of test data # The path of test data
test:data = "machine.txt.test" test:data = "machine.txt.test"

View File

@ -1,3 +1,2 @@
tracker.py tracker.py
build.sh build.sh

View File

@ -21,19 +21,23 @@ import org.apache.spark.Partitioner
import org.apache.spark.ml.feature.VectorAssembler import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession import org.apache.spark.sql.SparkSession
import org.scalatest.FunSuite import org.scalatest.FunSuite
import org.apache.spark.sql.functions._
import scala.util.Random import scala.util.Random
class FeatureSizeValidatingSuite extends FunSuite with PerTest { class FeatureSizeValidatingSuite extends FunSuite with PerTest {
test("transform throwing exception if feature size of dataset is different with model's") { test("transform throwing exception if feature size of dataset is greater than model's") {
val modelPath = getClass.getResource("/model/0.82/model").getPath val modelPath = getClass.getResource("/model/0.82/model").getPath
val model = XGBoostClassificationModel.read.load(modelPath) val model = XGBoostClassificationModel.read.load(modelPath)
val r = new Random(0) val r = new Random(0)
// 0.82/model was trained with 251 features. and transform will throw exception // 0.82/model was trained with 251 features. and transform will throw exception
// if feature size of data is not equal to 251 // if feature size of data is not equal to 251
val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))). var df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
toDF("feature", "label") toDF("feature", "label")
for (x <- 1 to 252) {
df = df.withColumn(s"feature_${x}", lit(1))
}
val assembler = new VectorAssembler() val assembler = new VectorAssembler()
.setInputCols(df.columns.filter(!_.contains("label"))) .setInputCols(df.columns.filter(!_.contains("label")))
.setOutputCol("features") .setOutputCol("features")
@ -67,5 +71,4 @@ class FeatureSizeValidatingSuite extends FunSuite with PerTest {
xgb.fit(repartitioned) xgb.fit(repartitioned)
} }
} }
} }

View File

@ -946,7 +946,7 @@ class LearnerImpl : public LearnerIO {
common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter); common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter);
} }
this->CheckDataSplitMode(); this->CheckDataSplitMode();
this->ValidateDMatrix(train.get()); this->ValidateDMatrix(train.get(), true);
auto& predt = this->cache_.Cache(train, generic_parameters_.gpu_id); auto& predt = this->cache_.Cache(train, generic_parameters_.gpu_id);
@ -972,7 +972,7 @@ class LearnerImpl : public LearnerIO {
common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter); common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter);
} }
this->CheckDataSplitMode(); this->CheckDataSplitMode();
this->ValidateDMatrix(train.get()); this->ValidateDMatrix(train.get(), true);
this->cache_.Cache(train, generic_parameters_.gpu_id); this->cache_.Cache(train, generic_parameters_.gpu_id);
gbm_->DoBoost(train.get(), in_gpair, &cache_.Entry(train.get())); gbm_->DoBoost(train.get(), in_gpair, &cache_.Entry(train.get()));
@ -994,7 +994,7 @@ class LearnerImpl : public LearnerIO {
for (size_t i = 0; i < data_sets.size(); ++i) { for (size_t i = 0; i < data_sets.size(); ++i) {
std::shared_ptr<DMatrix> m = data_sets[i]; std::shared_ptr<DMatrix> m = data_sets[i];
auto &predt = this->cache_.Cache(m, generic_parameters_.gpu_id); auto &predt = this->cache_.Cache(m, generic_parameters_.gpu_id);
this->ValidateDMatrix(m.get()); this->ValidateDMatrix(m.get(), false);
this->PredictRaw(m.get(), &predt, false); this->PredictRaw(m.get(), &predt, false);
auto &out = output_predictions_.Cache(m, generic_parameters_.gpu_id).predictions; auto &out = output_predictions_.Cache(m, generic_parameters_.gpu_id).predictions;
@ -1079,11 +1079,11 @@ class LearnerImpl : public LearnerIO {
bool training, bool training,
unsigned ntree_limit = 0) const { unsigned ntree_limit = 0) const {
CHECK(gbm_ != nullptr) << "Predict must happen after Load or configuration"; CHECK(gbm_ != nullptr) << "Predict must happen after Load or configuration";
this->ValidateDMatrix(data); this->ValidateDMatrix(data, false);
gbm_->PredictBatch(data, out_preds, training, ntree_limit); gbm_->PredictBatch(data, out_preds, training, ntree_limit);
} }
void ValidateDMatrix(DMatrix* p_fmat) const { void ValidateDMatrix(DMatrix* p_fmat, bool is_training) const {
MetaInfo const& info = p_fmat->Info(); MetaInfo const& info = p_fmat->Info();
info.Validate(generic_parameters_.gpu_id); info.Validate(generic_parameters_.gpu_id);
@ -1092,8 +1092,15 @@ class LearnerImpl : public LearnerIO {
tparam_.dsplit == DataSplitMode::kAuto; tparam_.dsplit == DataSplitMode::kAuto;
}; };
if (row_based_split()) { if (row_based_split()) {
if (is_training) {
CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_) CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_)
<< "Number of columns does not match number of features in booster."; << "Number of columns does not match number of features in "
"booster.";
} else {
CHECK_GE(learner_model_param_.num_feature, p_fmat->Info().num_col_)
<< "Number of columns does not match number of features in "
"booster.";
}
} }
} }

View File

@ -165,4 +165,8 @@ TEST(CpuPredictor, InplacePredict) {
TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1); TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1);
} }
} }
// Runs the shared "lesser features" prediction check against the CPU predictor.
TEST(CpuPredictor, LesserFeatures) {
  constexpr char const* kPredictorName = "cpu_predictor";
  TestPredictionWithLesserFeatures(kPredictorName);
}
} // namespace xgboost } // namespace xgboost

View File

@ -160,5 +160,8 @@ TEST(GPUPredictor, MGPU_InplacePredict) { // NOLINT
dmlc::Error); dmlc::Error);
} }
// Runs the shared "lesser features" prediction check against the GPU predictor.
// The suite is spelled `GPUPredictor`, like the other tests in this file, so
// that a `GPUPredictor.*` filter selects this test as well.
TEST(GPUPredictor, LesserFeatures) {
  TestPredictionWithLesserFeatures("gpu_predictor");
}
} // namespace predictor } // namespace predictor
} // namespace xgboost } // namespace xgboost

View File

@ -134,4 +134,44 @@ void TestInplacePrediction(dmlc::any x, std::string predictor,
learner->SetParam("gpu_id", "-1"); learner->SetParam("gpu_id", "-1");
learner->Configure(); learner->Configure();
} }
/*!
 * \brief Train on a wide matrix, then predict on matrices whose column count
 *        differs from the training data.
 *
 * The learner is trained on a matrix with kTrainCols columns.  Prediction on a
 * matrix with only kTestCols columns must succeed and yield kRows values, while
 * prediction on a matrix with kTrainCols + 1 columns must throw dmlc::Error.
 * When built with CUDA, the CPU and GPU predictors are additionally required to
 * agree (within kRtEps) on the narrow matrix.
 *
 * \param predictor_name Value assigned to the "predictor" learner parameter.
 */
void TestPredictionWithLesserFeatures(std::string predictor_name) {
  size_t constexpr kRows = 256, kTrainCols = 256, kTestCols = 4, kIters = 4;
  // NOTE(review): the bool passed to GenerateDMatrix presumably toggles label
  // generation (true for training, false for test data) -- confirm.
  auto m_train = RandomDataGenerator(kRows, kTrainCols, 0.5).GenerateDMatrix(true);
  auto m_test = RandomDataGenerator(kRows, kTestCols, 0.5).GenerateDMatrix(false);
  std::unique_ptr<Learner> learner{Learner::Create({m_train})};

  for (size_t i = 0; i < kIters; ++i) {
    learner->UpdateOneIter(i, m_train);
  }

  HostDeviceVector<float> prediction;
  learner->SetParam("predictor", predictor_name);
  learner->Configure();

  // Make sure the requested predictor really ended up in the configuration.
  Json config{Object()};
  learner->SaveConfig(&config);
  auto const& configured_predictor = get<String>(
      config["learner"]["gradient_booster"]["gbtree_train_param"]["predictor"]);
  ASSERT_EQ(configured_predictor, predictor_name);

  // Fewer columns than the training data: accepted.
  learner->Predict(m_test, false, &prediction);
  ASSERT_EQ(prediction.Size(), kRows);

  // More columns than the training data: rejected.
  auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false);
  ASSERT_THROW({learner->Predict(m_invalid, false, &prediction);}, dmlc::Error);

#if defined(XGBOOST_USE_CUDA)
  HostDeviceVector<float> from_cpu;
  learner->SetParam("predictor", "cpu_predictor");
  learner->Predict(m_test, false, &from_cpu);

  HostDeviceVector<float> from_cuda;
  learner->SetParam("predictor", "gpu_predictor");
  learner->Predict(m_test, false, &from_cuda);

  auto const& h_cpu = from_cpu.ConstHostVector();
  auto const& h_gpu = from_cuda.ConstHostVector();
  // Guard the element-wise loop below: without this a shorter GPU result
  // would be read out of bounds instead of failing the test.
  ASSERT_EQ(h_cpu.size(), h_gpu.size());
  for (size_t i = 0; i < h_cpu.size(); ++i) {
    ASSERT_NEAR(h_cpu[i], h_gpu[i], kRtEps);
  }
#endif  // defined(XGBOOST_USE_CUDA)
}
} // namespace xgboost } // namespace xgboost

View File

@ -59,6 +59,8 @@ void TestTrainingPrediction(size_t rows, size_t bins, std::string tree_method,
void TestInplacePrediction(dmlc::any x, std::string predictor, void TestInplacePrediction(dmlc::any x, std::string predictor,
bst_row_t rows, bst_feature_t cols, bst_row_t rows, bst_feature_t cols,
int32_t device = -1); int32_t device = -1);
/*!
 * \brief Shared check for prediction on data whose number of columns differs
 *        from the training data.
 *
 * \param predictor_name Value assigned to the "predictor" learner parameter,
 *                       e.g. "cpu_predictor" or "gpu_predictor".
 */
void TestPredictionWithLesserFeatures(std::string predictor_name);
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_TEST_PREDICTOR_H_ #endif // XGBOOST_TEST_PREDICTOR_H_

View File

@ -6,6 +6,7 @@ import xgboost
import subprocess import subprocess
import numpy import numpy
import json import json
import testing as tm
class TestCLI(unittest.TestCase): class TestCLI(unittest.TestCase):
@ -28,22 +29,20 @@ data = {data_path}
eval[test] = {data_path} eval[test] = {data_path}
''' '''
curdir = os.path.normpath(os.path.abspath(os.path.dirname(__file__))) PROJECT_ROOT = tm.PROJECT_ROOT
project_root = os.path.normpath(
os.path.join(curdir, os.path.pardir, os.path.pardir))
def get_exe(self): def get_exe(self):
if platform.system() == 'Windows': if platform.system() == 'Windows':
exe = 'xgboost.exe' exe = 'xgboost.exe'
else: else:
exe = 'xgboost' exe = 'xgboost'
exe = os.path.join(self.project_root, exe) exe = os.path.join(self.PROJECT_ROOT, exe)
assert os.path.exists(exe) assert os.path.exists(exe)
return exe return exe
def test_cli_model(self): def test_cli_model(self):
data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format( data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
root=self.project_root) root=self.PROJECT_ROOT)
exe = self.get_exe() exe = self.get_exe()
seed = 1994 seed = 1994
@ -128,7 +127,7 @@ eval[test] = {data_path}
def test_cli_model_json(self): def test_cli_model_json(self):
exe = self.get_exe() exe = self.get_exe()
data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format( data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
root=self.project_root) root=self.PROJECT_ROOT)
seed = 1994 seed = 1994
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:

View File

@ -117,3 +117,18 @@ def test_aft_demo():
# gamma regression is not tested as it requires running an R script first. # gamma regression is not tested as it requires running an R script first.
# aft viz is not tested because plotting is not controlled # aft viz is not tested because plotting is not controlled
# aft tuning is not tested due to extra dependency. # aft tuning is not tested due to extra dependency.
def test_cli_regression_demo():
    """Run the CLI regression demo end to end.

    Runs the demo's ``mapfeat.py`` and ``mknfold.py`` scripts and then the
    ``xgboost`` command line executable on ``machine.conf``.  Every step is
    executed with the regression demo directory as working directory, and
    any non-zero exit status fails the test via ``subprocess.check_call``.
    """
    # Local import: only needed here to pick the executable's file name.
    import platform

    reg_dir = os.path.join(DEMO_DIR, 'regression')

    script = os.path.join(reg_dir, 'mapfeat.py')
    cmd = ['python', script]
    subprocess.check_call(cmd, cwd=reg_dir)

    script = os.path.join(reg_dir, 'mknfold.py')
    cmd = ['python', script, 'machine.txt', '1']
    subprocess.check_call(cmd, cwd=reg_dir)

    # Same naming rule as the CLI tests: the binary carries an ``.exe``
    # suffix on Windows.
    exe_name = 'xgboost.exe' if platform.system() == 'Windows' else 'xgboost'
    exe = os.path.join(tm.PROJECT_ROOT, exe_name)
    # Fail with a clear message when the CLI has not been built, instead of
    # an opaque error from the subprocess call.
    assert os.path.exists(exe), 'xgboost CLI executable not found: ' + exe

    conf = os.path.join(reg_dir, 'machine.conf')
    subprocess.check_call([exe, conf], cwd=reg_dir)

View File

@ -216,3 +216,8 @@ dataset_strategy = _dataset_and_weight()
def non_increasing(L, tolerance=1e-4): def non_increasing(L, tolerance=1e-4):
return all((y - x) < tolerance for x, y in zip(L, L[1:])) return all((y - x) < tolerance for x, y in zip(L, L[1:]))
# Absolute, normalised path of the directory that contains this module.
_THIS_DIR = os.path.dirname(__file__)
CURDIR = os.path.normpath(os.path.abspath(_THIS_DIR))

# Directory two levels above CURDIR.
_TWO_LEVELS_UP = os.path.join(CURDIR, os.pardir, os.pardir)
PROJECT_ROOT = os.path.normpath(_TWO_LEVELS_UP)