Fix prediction heuristic (#5955)
* Relax check for prediction.
* Relax test in spark test.
* Add tests in C++.
parent 5879acde9a
commit 75b8c22b0b
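The core of the fix is in `LearnerImpl::ValidateDMatrix` (see the `src/learner.cc` hunks below): training keeps the exact `CHECK_EQ` between the data's column count and the booster's feature count, while prediction relaxes it to `CHECK_GE`, so a test matrix may carry fewer columns than the model has features, with the absent ones treated as missing. A minimal standalone sketch of the relaxed rule, using illustrative names rather than the actual XGBoost internals:

```cpp
#include <cstddef>
#include <sstream>
#include <stdexcept>

// Sketch of the relaxed heuristic, assuming row-based data splitting as in
// the patch. Training demands an exact column/feature match; prediction only
// rejects data with *more* columns than the model knows about.
void ValidateColumns(std::size_t num_feature, std::size_t num_col, bool is_training) {
  bool const ok = is_training ? (num_col == num_feature) : (num_col <= num_feature);
  if (!ok) {
    std::ostringstream msg;
    msg << "Number of columns (" << num_col
        << ") does not match number of features in booster (" << num_feature << ").";
    throw std::invalid_argument(msg.str());
  }
}
```

Extra columns are still rejected in both modes, since the model cannot interpret features it was never trained on.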
demo/regression/machine.conf

@@ -26,5 +26,3 @@ data = "machine.txt.train"
 eval[test] = "machine.txt.test"
-# The path of test data
-test:data = "machine.txt.test"
jvm-packages/.gitignore (vendored)

@@ -1,3 +1,2 @@
 tracker.py
 build.sh
-
jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/FeatureSizeValidatingSuite.scala

@@ -21,19 +21,23 @@ import org.apache.spark.Partitioner
 import org.apache.spark.ml.feature.VectorAssembler
 import org.apache.spark.sql.SparkSession
 import org.scalatest.FunSuite
+import org.apache.spark.sql.functions._

 import scala.util.Random

 class FeatureSizeValidatingSuite extends FunSuite with PerTest {

-  test("transform throwing exception if feature size of dataset is different with model's") {
+  test("transform throwing exception if feature size of dataset is greater than model's") {
     val modelPath = getClass.getResource("/model/0.82/model").getPath
     val model = XGBoostClassificationModel.read.load(modelPath)
     val r = new Random(0)
     // 0.82/model was trained with 251 features, and transform will throw an
     // exception if feature size of data is greater than 251
-    val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
+    var df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
       toDF("feature", "label")
+    for (x <- 1 to 252) {
+      df = df.withColumn(s"feature_${x}", lit(1))
+    }
     val assembler = new VectorAssembler()
       .setInputCols(df.columns.filter(!_.contains("label")))
       .setOutputCol("features")

@@ -67,5 +71,4 @@ class FeatureSizeValidatingSuite extends FunSuite with PerTest {
     xgb.fit(repartitioned)
   }
 }
-
 }
src/learner.cc

@@ -946,7 +946,7 @@ class LearnerImpl : public LearnerIO {
       common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter);
     }
     this->CheckDataSplitMode();
-    this->ValidateDMatrix(train.get());
+    this->ValidateDMatrix(train.get(), true);

     auto& predt = this->cache_.Cache(train, generic_parameters_.gpu_id);

@@ -972,7 +972,7 @@ class LearnerImpl : public LearnerIO {
       common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter);
     }
     this->CheckDataSplitMode();
-    this->ValidateDMatrix(train.get());
+    this->ValidateDMatrix(train.get(), true);
     this->cache_.Cache(train, generic_parameters_.gpu_id);

     gbm_->DoBoost(train.get(), in_gpair, &cache_.Entry(train.get()));

@@ -994,7 +994,7 @@ class LearnerImpl : public LearnerIO {
     for (size_t i = 0; i < data_sets.size(); ++i) {
       std::shared_ptr<DMatrix> m = data_sets[i];
       auto &predt = this->cache_.Cache(m, generic_parameters_.gpu_id);
-      this->ValidateDMatrix(m.get());
+      this->ValidateDMatrix(m.get(), false);
       this->PredictRaw(m.get(), &predt, false);

       auto &out = output_predictions_.Cache(m, generic_parameters_.gpu_id).predictions;

@@ -1079,11 +1079,11 @@ class LearnerImpl : public LearnerIO {
                   bool training,
                   unsigned ntree_limit = 0) const {
     CHECK(gbm_ != nullptr) << "Predict must happen after Load or configuration";
-    this->ValidateDMatrix(data);
+    this->ValidateDMatrix(data, false);
     gbm_->PredictBatch(data, out_preds, training, ntree_limit);
   }

-  void ValidateDMatrix(DMatrix* p_fmat) const {
+  void ValidateDMatrix(DMatrix* p_fmat, bool is_training) const {
     MetaInfo const& info = p_fmat->Info();
     info.Validate(generic_parameters_.gpu_id);

@@ -1092,8 +1092,15 @@ class LearnerImpl : public LearnerIO {
       tparam_.dsplit == DataSplitMode::kAuto;
     };
     if (row_based_split()) {
-      CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_)
-          << "Number of columns does not match number of features in booster.";
+      if (is_training) {
+        CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_)
+            << "Number of columns does not match number of features in "
+               "booster.";
+      } else {
+        CHECK_GE(learner_model_param_.num_feature, p_fmat->Info().num_col_)
+            << "Number of columns does not match number of features in "
+               "booster.";
+      }
     }
   }
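In the hunks above, the training paths (`UpdateOneIter`, `BoostOneIter`) validate with `is_training = true`, while evaluation and prediction pass `false`. In user-facing terms the new behaviour looks roughly like this hedged sketch against the C API of this era (the exact `XGBoosterPredict` parameter list has varied across releases, so treat it as an assumption):

```cpp
#include <xgboost/c_api.h>
#include <cstdio>

int main() {
  // Train a booster on 4 features.
  float train[8][4] = {};
  float labels[8] = {0, 1, 0, 1, 0, 1, 0, 1};
  DMatrixHandle dtrain;
  XGDMatrixCreateFromMat(&train[0][0], 8, 4, -1.0f, &dtrain);
  XGDMatrixSetFloatInfo(dtrain, "label", labels, 8);

  BoosterHandle booster;
  XGBoosterCreate(&dtrain, 1, &booster);
  for (int iter = 0; iter < 4; ++iter) {
    XGBoosterUpdateOneIter(booster, iter, dtrain);  // training: exact column match still required
  }

  // Predict on a 2-column matrix: fewer columns than the model's features is
  // now accepted, the absent columns being treated as missing. A 5-column
  // matrix would still fail the CHECK_GE and raise an error.
  float test[2][2] = {};
  DMatrixHandle dtest;
  XGDMatrixCreateFromMat(&test[0][0], 2, 2, -1.0f, &dtest);

  bst_ulong len = 0;
  float const *out = nullptr;
  int rc = XGBoosterPredict(booster, dtest, 0, 0, 0, &len, &out);
  std::printf("rc=%d, %lu predictions\n", rc, static_cast<unsigned long>(len));

  XGDMatrixFree(dtest);
  XGDMatrixFree(dtrain);
  XGBoosterFree(booster);
  return 0;
}
```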
tests/cpp/predictor/test_cpu_predictor.cc

@@ -165,4 +165,8 @@ TEST(CpuPredictor, InplacePredict) {
   TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1);
 }
 }
+
+TEST(CpuPredictor, LesserFeatures) {
+  TestPredictionWithLesserFeatures("cpu_predictor");
+}
 }  // namespace xgboost
tests/cpp/predictor/test_gpu_predictor.cu

@@ -160,5 +160,8 @@ TEST(GPUPredictor, MGPU_InplacePredict) {  // NOLINT
       dmlc::Error);
 }

+TEST(GpuPredictor, LesserFeatures) {
+  TestPredictionWithLesserFeatures("gpu_predictor");
+}
 }  // namespace predictor
 }  // namespace xgboost
tests/cpp/predictor/test_predictor.cc

@@ -134,4 +134,44 @@ void TestInplacePrediction(dmlc::any x, std::string predictor,
   learner->SetParam("gpu_id", "-1");
   learner->Configure();
 }
+
+void TestPredictionWithLesserFeatures(std::string predictor_name) {
+  size_t constexpr kRows = 256, kTrainCols = 256, kTestCols = 4, kIters = 4;
+  auto m_train = RandomDataGenerator(kRows, kTrainCols, 0.5).GenerateDMatrix(true);
+  auto m_test = RandomDataGenerator(kRows, kTestCols, 0.5).GenerateDMatrix(false);
+  std::unique_ptr<Learner> learner{Learner::Create({m_train})};
+
+  for (size_t i = 0; i < kIters; ++i) {
+    learner->UpdateOneIter(i, m_train);
+  }
+
+  HostDeviceVector<float> prediction;
+  learner->SetParam("predictor", predictor_name);
+  learner->Configure();
+  Json config{Object()};
+  learner->SaveConfig(&config);
+  ASSERT_EQ(get<String>(config["learner"]["gradient_booster"]["gbtree_train_param"]["predictor"]), predictor_name);
+
+  learner->Predict(m_test, false, &prediction);
+  ASSERT_EQ(prediction.Size(), kRows);
+
+  auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false);
+  ASSERT_THROW({learner->Predict(m_invalid, false, &prediction);}, dmlc::Error);
+
+#if defined(XGBOOST_USE_CUDA)
+  HostDeviceVector<float> from_cpu;
+  learner->SetParam("predictor", "cpu_predictor");
+  learner->Predict(m_test, false, &from_cpu);
+
+  HostDeviceVector<float> from_cuda;
+  learner->SetParam("predictor", "gpu_predictor");
+  learner->Predict(m_test, false, &from_cuda);
+
+  auto const& h_cpu = from_cpu.ConstHostVector();
+  auto const& h_gpu = from_cuda.ConstHostVector();
+  for (size_t i = 0; i < h_cpu.size(); ++i) {
+    ASSERT_NEAR(h_cpu[i], h_gpu[i], kRtEps);
+  }
+#endif  // defined(XGBOOST_USE_CUDA)
+}
 }  // namespace xgboost
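The CUDA-guarded tail of `TestPredictionWithLesserFeatures` checks CPU and GPU predictions element-wise with `ASSERT_NEAR` and `kRtEps` rather than exact equality. The comparison idiom, as a self-contained sketch with illustrative names:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Element-wise near-equality with an absolute tolerance, mirroring the
// ASSERT_NEAR(h_cpu[i], h_gpu[i], kRtEps) loop in the new test. Exact float
// equality is too strict across devices, since CPU and GPU kernels may
// accumulate sums in different orders.
bool AllNear(std::vector<float> const& a, std::vector<float> const& b, float eps) {
  if (a.size() != b.size()) {
    return false;
  }
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (std::fabs(a[i] - b[i]) > eps) {
      return false;  // difference beyond tolerance
    }
  }
  return true;
}
```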
tests/cpp/predictor/test_predictor.h

@@ -59,6 +59,8 @@ void TestTrainingPrediction(size_t rows, size_t bins, std::string tree_method,
 void TestInplacePrediction(dmlc::any x, std::string predictor,
                            bst_row_t rows, bst_feature_t cols,
                            int32_t device = -1);
+
+void TestPredictionWithLesserFeatures(std::string predictor_name);
 }  // namespace xgboost

 #endif  // XGBOOST_TEST_PREDICTOR_H_
tests/python/test_cli.py

@@ -6,6 +6,7 @@ import xgboost
 import subprocess
 import numpy
 import json
+import testing as tm


 class TestCLI(unittest.TestCase):

@@ -28,22 +29,20 @@ data = {data_path}
 eval[test] = {data_path}
 '''

-    curdir = os.path.normpath(os.path.abspath(os.path.dirname(__file__)))
-    project_root = os.path.normpath(
-        os.path.join(curdir, os.path.pardir, os.path.pardir))
+    PROJECT_ROOT = tm.PROJECT_ROOT

     def get_exe(self):
         if platform.system() == 'Windows':
             exe = 'xgboost.exe'
         else:
             exe = 'xgboost'
-        exe = os.path.join(self.project_root, exe)
+        exe = os.path.join(self.PROJECT_ROOT, exe)
         assert os.path.exists(exe)
         return exe

     def test_cli_model(self):
         data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
-            root=self.project_root)
+            root=self.PROJECT_ROOT)
         exe = self.get_exe()
         seed = 1994

@@ -128,7 +127,7 @@ eval[test] = {data_path}
     def test_cli_model_json(self):
         exe = self.get_exe()
         data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
-            root=self.project_root)
+            root=self.PROJECT_ROOT)
         seed = 1994

         with tempfile.TemporaryDirectory() as tmpdir:
tests/python/test_demos.py

@@ -117,3 +117,18 @@ def test_aft_demo():
 # gamma regression is not tested as it requires running an R script first.
 # aft viz is not tested because plotting is not controlled.
 # aft tuning is not tested due to an extra dependency.
+
+
+def test_cli_regression_demo():
+    reg_dir = os.path.join(DEMO_DIR, 'regression')
+    script = os.path.join(reg_dir, 'mapfeat.py')
+    cmd = ['python', script]
+    subprocess.check_call(cmd, cwd=reg_dir)
+
+    script = os.path.join(reg_dir, 'mknfold.py')
+    cmd = ['python', script, 'machine.txt', '1']
+    subprocess.check_call(cmd, cwd=reg_dir)
+
+    exe = os.path.join(tm.PROJECT_ROOT, 'xgboost')
+    conf = os.path.join(reg_dir, 'machine.conf')
+    subprocess.check_call([exe, conf], cwd=reg_dir)
tests/python/testing.py

@@ -216,3 +216,8 @@ dataset_strategy = _dataset_and_weight()
 def non_increasing(L, tolerance=1e-4):
     return all((y - x) < tolerance for x, y in zip(L, L[1:]))
+
+
+CURDIR = os.path.normpath(os.path.abspath(os.path.dirname(__file__)))
+PROJECT_ROOT = os.path.normpath(
+    os.path.join(CURDIR, os.path.pardir, os.path.pardir))