Refactor the CLI. (#5574)
* Enable parameter validation. * Enable JSON. * Catch `dmlc::Error`. * Show help message.
This commit is contained in:
@@ -5,6 +5,7 @@ import platform
|
||||
import xgboost
|
||||
import subprocess
|
||||
import numpy
|
||||
import json
|
||||
|
||||
|
||||
class TestCLI(unittest.TestCase):
|
||||
@@ -27,20 +28,23 @@ data = {data_path}
|
||||
eval[test] = {data_path}
|
||||
'''
|
||||
|
||||
def test_cli_model(self):
|
||||
curdir = os.path.normpath(os.path.abspath(os.path.dirname(__file__)))
|
||||
project_root = os.path.normpath(
|
||||
os.path.join(curdir, os.path.pardir, os.path.pardir))
|
||||
data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
|
||||
root=project_root)
|
||||
curdir = os.path.normpath(os.path.abspath(os.path.dirname(__file__)))
|
||||
project_root = os.path.normpath(
|
||||
os.path.join(curdir, os.path.pardir, os.path.pardir))
|
||||
|
||||
def get_exe(self):
|
||||
if platform.system() == 'Windows':
|
||||
exe = 'xgboost.exe'
|
||||
else:
|
||||
exe = 'xgboost'
|
||||
exe = os.path.join(project_root, exe)
|
||||
exe = os.path.join(self.project_root, exe)
|
||||
assert os.path.exists(exe)
|
||||
return exe
|
||||
|
||||
def test_cli_model(self):
|
||||
data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
|
||||
root=self.project_root)
|
||||
exe = self.get_exe()
|
||||
seed = 1994
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
@@ -102,3 +106,48 @@ eval[test] = {data_path}
|
||||
py_model_bin = fd.read()
|
||||
|
||||
assert hash(cli_model_bin) == hash(py_model_bin)
|
||||
|
||||
def test_cli_help(self):
|
||||
exe = self.get_exe()
|
||||
completed = subprocess.run([exe], stdout=subprocess.PIPE)
|
||||
error_msg = completed.stdout.decode('utf-8')
|
||||
ret = completed.returncode
|
||||
assert ret == 1
|
||||
assert error_msg.find('Usage') != -1
|
||||
assert error_msg.find('eval[NAME]') != -1
|
||||
|
||||
completed = subprocess.run([exe, '-V'], stdout=subprocess.PIPE)
|
||||
msg = completed.stdout.decode('utf-8')
|
||||
assert msg.find('XGBoost') != -1
|
||||
v = xgboost.__version__
|
||||
if v.find('SNAPSHOT') != -1:
|
||||
assert msg.split(':')[1].strip() == v.split('-')[0]
|
||||
else:
|
||||
assert msg.split(':')[1].strip() == v
|
||||
|
||||
def test_cli_model_json(self):
|
||||
exe = self.get_exe()
|
||||
data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
|
||||
root=self.project_root)
|
||||
seed = 1994
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model_out_cli = os.path.join(
|
||||
tmpdir, 'test_load_cli_model-cli.json')
|
||||
config_path = os.path.join(tmpdir, 'test_load_cli_model.conf')
|
||||
|
||||
train_conf = self.template.format(data_path=data_path,
|
||||
seed=seed,
|
||||
task='train',
|
||||
model_in='NULL',
|
||||
model_out=model_out_cli,
|
||||
test_path='NULL',
|
||||
name_pred='NULL')
|
||||
with open(config_path, 'w') as fd:
|
||||
fd.write(train_conf)
|
||||
|
||||
subprocess.run([exe, config_path])
|
||||
with open(model_out_cli, 'r') as fd:
|
||||
model = json.load(fd)
|
||||
|
||||
assert model['learner']['gradient_booster']['name'] == 'gbtree'
|
||||
|
||||
Reference in New Issue
Block a user