Fix prediction heuristic (#5955)
* Relax check for prediction. * Relax test in spark test. * Add tests in C++.
This commit is contained in:
@@ -6,6 +6,7 @@ import xgboost
|
||||
import subprocess
|
||||
import numpy
|
||||
import json
|
||||
import testing as tm
|
||||
|
||||
|
||||
class TestCLI(unittest.TestCase):
|
||||
@@ -28,22 +29,20 @@ data = {data_path}
|
||||
eval[test] = {data_path}
|
||||
'''
|
||||
|
||||
curdir = os.path.normpath(os.path.abspath(os.path.dirname(__file__)))
|
||||
project_root = os.path.normpath(
|
||||
os.path.join(curdir, os.path.pardir, os.path.pardir))
|
||||
PROJECT_ROOT = tm.PROJECT_ROOT
|
||||
|
||||
def get_exe(self):
|
||||
if platform.system() == 'Windows':
|
||||
exe = 'xgboost.exe'
|
||||
else:
|
||||
exe = 'xgboost'
|
||||
exe = os.path.join(self.project_root, exe)
|
||||
exe = os.path.join(self.PROJECT_ROOT, exe)
|
||||
assert os.path.exists(exe)
|
||||
return exe
|
||||
|
||||
def test_cli_model(self):
|
||||
data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
|
||||
root=self.project_root)
|
||||
root=self.PROJECT_ROOT)
|
||||
exe = self.get_exe()
|
||||
seed = 1994
|
||||
|
||||
@@ -128,7 +127,7 @@ eval[test] = {data_path}
|
||||
def test_cli_model_json(self):
|
||||
exe = self.get_exe()
|
||||
data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
|
||||
root=self.project_root)
|
||||
root=self.PROJECT_ROOT)
|
||||
seed = 1994
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
|
||||
@@ -117,3 +117,18 @@ def test_aft_demo():
|
||||
# gamma regression is not tested as it requires running a R script first.
|
||||
# aft viz is not tested due to ploting is not controled
|
||||
# aft tunning is not tested due to extra dependency.
|
||||
|
||||
|
||||
def test_cli_regression_demo():
|
||||
reg_dir = os.path.join(DEMO_DIR, 'regression')
|
||||
script = os.path.join(reg_dir, 'mapfeat.py')
|
||||
cmd = ['python', script]
|
||||
subprocess.check_call(cmd, cwd=reg_dir)
|
||||
|
||||
script = os.path.join(reg_dir, 'mknfold.py')
|
||||
cmd = ['python', script, 'machine.txt', '1']
|
||||
subprocess.check_call(cmd, cwd=reg_dir)
|
||||
|
||||
exe = os.path.join(tm.PROJECT_ROOT, 'xgboost')
|
||||
conf = os.path.join(reg_dir, 'machine.conf')
|
||||
subprocess.check_call([exe, conf], cwd=reg_dir)
|
||||
|
||||
@@ -216,3 +216,8 @@ dataset_strategy = _dataset_and_weight()
|
||||
|
||||
def non_increasing(L, tolerance=1e-4):
|
||||
return all((y - x) < tolerance for x, y in zip(L, L[1:]))
|
||||
|
||||
|
||||
CURDIR = os.path.normpath(os.path.abspath(os.path.dirname(__file__)))
|
||||
PROJECT_ROOT = os.path.normpath(
|
||||
os.path.join(CURDIR, os.path.pardir, os.path.pardir))
|
||||
|
||||
Reference in New Issue
Block a user