JSON configuration IO. (#5111)
* Add saving/loading JSON configuration.
* Implement Python pickle interface with new IO routines.
* Basic tests for training continuation.
This commit is contained in:
@@ -203,7 +203,7 @@ class TestModels(unittest.TestCase):
|
||||
self.assertRaises(ValueError, bst.predict, dm1)
|
||||
bst.predict(dm2) # success
|
||||
|
||||
def test_json_model_io(self):
|
||||
def test_model_json_io(self):
|
||||
X = np.random.random((10, 3))
|
||||
y = np.random.randint(2, size=(10,))
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import pickle
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
import os
|
||||
import unittest
|
||||
|
||||
|
||||
kRows = 100
|
||||
@@ -14,35 +15,45 @@ def generate_data():
|
||||
return X, y
|
||||
|
||||
|
||||
def test_model_pickling():
|
||||
xgb_params = {
|
||||
'verbosity': 0,
|
||||
'nthread': 1,
|
||||
'tree_method': 'hist'
|
||||
}
|
||||
class TestPickling(unittest.TestCase):
|
||||
def run_model_pickling(self, xgb_params):
|
||||
X, y = generate_data()
|
||||
dtrain = xgb.DMatrix(X, y)
|
||||
bst = xgb.train(xgb_params, dtrain)
|
||||
|
||||
X, y = generate_data()
|
||||
dtrain = xgb.DMatrix(X, y)
|
||||
bst = xgb.train(xgb_params, dtrain)
|
||||
dump_0 = bst.get_dump(dump_format='json')
|
||||
assert dump_0
|
||||
|
||||
dump_0 = bst.get_dump(dump_format='json')
|
||||
assert dump_0
|
||||
filename = 'model.pkl'
|
||||
|
||||
filename = 'model.pkl'
|
||||
with open(filename, 'wb') as fd:
|
||||
pickle.dump(bst, fd)
|
||||
|
||||
with open(filename, 'wb') as fd:
|
||||
pickle.dump(bst, fd)
|
||||
with open(filename, 'rb') as fd:
|
||||
bst = pickle.load(fd)
|
||||
|
||||
with open(filename, 'rb') as fd:
|
||||
bst = pickle.load(fd)
|
||||
with open(filename, 'wb') as fd:
|
||||
pickle.dump(bst, fd)
|
||||
|
||||
with open(filename, 'wb') as fd:
|
||||
pickle.dump(bst, fd)
|
||||
with open(filename, 'rb') as fd:
|
||||
bst = pickle.load(fd)
|
||||
|
||||
with open(filename, 'rb') as fd:
|
||||
bst = pickle.load(fd)
|
||||
assert bst.get_dump(dump_format='json') == dump_0
|
||||
|
||||
assert bst.get_dump(dump_format='json') == dump_0
|
||||
if os.path.exists(filename):
|
||||
os.remove(filename)
|
||||
|
||||
if os.path.exists(filename):
|
||||
os.remove(filename)
|
||||
def test_model_pickling_binary(self):
|
||||
params = {
|
||||
'nthread': 1,
|
||||
'tree_method': 'hist'
|
||||
}
|
||||
self.run_model_pickling(params)
|
||||
|
||||
def test_model_pickling_json(self):
|
||||
params = {
|
||||
'nthread': 1,
|
||||
'tree_method': 'hist',
|
||||
'enable_experimental_json_serialization': True
|
||||
}
|
||||
self.run_model_pickling(params)
|
||||
|
||||
@@ -10,26 +10,35 @@ rng = np.random.RandomState(1337)
|
||||
class TestTrainingContinuation(unittest.TestCase):
|
||||
num_parallel_tree = 3
|
||||
|
||||
xgb_params_01 = {
|
||||
'verbosity': 0,
|
||||
'nthread': 1,
|
||||
}
|
||||
def generate_parameters(self, use_json):
|
||||
xgb_params_01_binary = {
|
||||
'nthread': 1,
|
||||
}
|
||||
|
||||
xgb_params_02 = {
|
||||
'verbosity': 0,
|
||||
'nthread': 1,
|
||||
'num_parallel_tree': num_parallel_tree
|
||||
}
|
||||
xgb_params_02_binary = {
|
||||
'nthread': 1,
|
||||
'num_parallel_tree': self.num_parallel_tree
|
||||
}
|
||||
|
||||
xgb_params_03 = {
|
||||
'verbosity': 0,
|
||||
'nthread': 1,
|
||||
'num_class': 5,
|
||||
'num_parallel_tree': num_parallel_tree
|
||||
}
|
||||
xgb_params_03_binary = {
|
||||
'nthread': 1,
|
||||
'num_class': 5,
|
||||
'num_parallel_tree': self.num_parallel_tree
|
||||
}
|
||||
if use_json:
|
||||
xgb_params_01_binary[
|
||||
'enable_experimental_json_serialization'] = True
|
||||
xgb_params_02_binary[
|
||||
'enable_experimental_json_serialization'] = True
|
||||
xgb_params_03_binary[
|
||||
'enable_experimental_json_serialization'] = True
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_training_continuation(self):
|
||||
return [
|
||||
xgb_params_01_binary, xgb_params_02_binary, xgb_params_03_binary
|
||||
]
|
||||
|
||||
def run_training_continuation(self, xgb_params_01, xgb_params_02,
|
||||
xgb_params_03):
|
||||
from sklearn.datasets import load_digits
|
||||
from sklearn.metrics import mean_squared_error
|
||||
|
||||
@@ -45,18 +54,18 @@ class TestTrainingContinuation(unittest.TestCase):
|
||||
dtrain_2class = xgb.DMatrix(X_2class, label=y_2class)
|
||||
dtrain_5class = xgb.DMatrix(X_5class, label=y_5class)
|
||||
|
||||
gbdt_01 = xgb.train(self.xgb_params_01, dtrain_2class,
|
||||
gbdt_01 = xgb.train(xgb_params_01, dtrain_2class,
|
||||
num_boost_round=10)
|
||||
ntrees_01 = len(gbdt_01.get_dump())
|
||||
assert ntrees_01 == 10
|
||||
|
||||
gbdt_02 = xgb.train(self.xgb_params_01, dtrain_2class,
|
||||
gbdt_02 = xgb.train(xgb_params_01, dtrain_2class,
|
||||
num_boost_round=0)
|
||||
gbdt_02.save_model('xgb_tc.model')
|
||||
|
||||
gbdt_02a = xgb.train(self.xgb_params_01, dtrain_2class,
|
||||
gbdt_02a = xgb.train(xgb_params_01, dtrain_2class,
|
||||
num_boost_round=10, xgb_model=gbdt_02)
|
||||
gbdt_02b = xgb.train(self.xgb_params_01, dtrain_2class,
|
||||
gbdt_02b = xgb.train(xgb_params_01, dtrain_2class,
|
||||
num_boost_round=10, xgb_model="xgb_tc.model")
|
||||
ntrees_02a = len(gbdt_02a.get_dump())
|
||||
ntrees_02b = len(gbdt_02b.get_dump())
|
||||
@@ -71,13 +80,13 @@ class TestTrainingContinuation(unittest.TestCase):
|
||||
res2 = mean_squared_error(y_2class, gbdt_02b.predict(dtrain_2class))
|
||||
assert res1 == res2
|
||||
|
||||
gbdt_03 = xgb.train(self.xgb_params_01, dtrain_2class,
|
||||
gbdt_03 = xgb.train(xgb_params_01, dtrain_2class,
|
||||
num_boost_round=3)
|
||||
gbdt_03.save_model('xgb_tc.model')
|
||||
|
||||
gbdt_03a = xgb.train(self.xgb_params_01, dtrain_2class,
|
||||
gbdt_03a = xgb.train(xgb_params_01, dtrain_2class,
|
||||
num_boost_round=7, xgb_model=gbdt_03)
|
||||
gbdt_03b = xgb.train(self.xgb_params_01, dtrain_2class,
|
||||
gbdt_03b = xgb.train(xgb_params_01, dtrain_2class,
|
||||
num_boost_round=7, xgb_model="xgb_tc.model")
|
||||
ntrees_03a = len(gbdt_03a.get_dump())
|
||||
ntrees_03b = len(gbdt_03b.get_dump())
|
||||
@@ -88,7 +97,7 @@ class TestTrainingContinuation(unittest.TestCase):
|
||||
res2 = mean_squared_error(y_2class, gbdt_03b.predict(dtrain_2class))
|
||||
assert res1 == res2
|
||||
|
||||
gbdt_04 = xgb.train(self.xgb_params_02, dtrain_2class,
|
||||
gbdt_04 = xgb.train(xgb_params_02, dtrain_2class,
|
||||
num_boost_round=3)
|
||||
assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration +
|
||||
1) * self.num_parallel_tree
|
||||
@@ -100,7 +109,7 @@ class TestTrainingContinuation(unittest.TestCase):
|
||||
ntree_limit=gbdt_04.best_ntree_limit))
|
||||
assert res1 == res2
|
||||
|
||||
gbdt_04 = xgb.train(self.xgb_params_02, dtrain_2class,
|
||||
gbdt_04 = xgb.train(xgb_params_02, dtrain_2class,
|
||||
num_boost_round=7, xgb_model=gbdt_04)
|
||||
assert gbdt_04.best_ntree_limit == (
|
||||
gbdt_04.best_iteration + 1) * self.num_parallel_tree
|
||||
@@ -112,11 +121,11 @@ class TestTrainingContinuation(unittest.TestCase):
|
||||
ntree_limit=gbdt_04.best_ntree_limit))
|
||||
assert res1 == res2
|
||||
|
||||
gbdt_05 = xgb.train(self.xgb_params_03, dtrain_5class,
|
||||
gbdt_05 = xgb.train(xgb_params_03, dtrain_5class,
|
||||
num_boost_round=7)
|
||||
assert gbdt_05.best_ntree_limit == (
|
||||
gbdt_05.best_iteration + 1) * self.num_parallel_tree
|
||||
gbdt_05 = xgb.train(self.xgb_params_03,
|
||||
gbdt_05 = xgb.train(xgb_params_03,
|
||||
dtrain_5class,
|
||||
num_boost_round=3,
|
||||
xgb_model=gbdt_05)
|
||||
@@ -127,3 +136,32 @@ class TestTrainingContinuation(unittest.TestCase):
|
||||
res2 = gbdt_05.predict(dtrain_5class,
|
||||
ntree_limit=gbdt_05.best_ntree_limit)
|
||||
np.testing.assert_almost_equal(res1, res2)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_training_continuation_binary(self):
|
||||
params = self.generate_parameters(False)
|
||||
self.run_training_continuation(params[0], params[1], params[2])
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_training_continuation_json(self):
|
||||
params = self.generate_parameters(True)
|
||||
for p in params:
|
||||
p['enable_experimental_json_serialization'] = True
|
||||
self.run_training_continuation(params[0], params[1], params[2])
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_training_continuation_updaters_binary(self):
|
||||
updaters = 'grow_colmaker,prune,refresh'
|
||||
params = self.generate_parameters(False)
|
||||
for p in params:
|
||||
p['updater'] = updaters
|
||||
self.run_training_continuation(params[0], params[1], params[2])
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_training_continuation_updaters_json(self):
|
||||
# Picked up from R tests.
|
||||
updaters = 'grow_colmaker,prune,refresh'
|
||||
params = self.generate_parameters(True)
|
||||
for p in params:
|
||||
p['updater'] = updaters
|
||||
self.run_training_continuation(params[0], params[1], params[2])
|
||||
|
||||
Reference in New Issue
Block a user