- Add typehints. - Fixes for pylint. Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
52 lines
1.8 KiB
Python
52 lines
1.8 KiB
Python
import json
|
|
|
|
import numpy as np
|
|
|
|
import xgboost as xgb
|
|
|
|
# Module-level random generator with a fixed seed (1994) so generated test
# data is reproducible from run to run.
rng = np.random.RandomState(seed=1994)
|
|
|
|
|
|
class TestGPUTrainingContinuation:
    """Training-continuation test for the ``gpu_hist`` tree method.

    Boosting 64 rounds in a single ``xgb.train`` call must yield the same
    trees (up to a small float tolerance) as boosting 32 rounds and then
    continuing that model for another 32 rounds via ``xgb_model``.
    """

    @classmethod
    def _recursive_compare(cls, obj_0, obj_1) -> None:
        """Assert that two parsed JSON tree dumps are equivalent.

        * floats are compared with ``np.isclose(..., atol=1e-6)``;
        * strings and ints must compare equal;
        * ``None`` (JSON ``null``) must match ``None``;
        * dicts must have identical keys in identical order, and every
          value except the one stored under ``'missing'`` is compared
          recursively;
        * anything else is treated as a sequence: lengths must match and
          elements are compared pairwise.

        Raises:
            AssertionError: on the first mismatch found.
        """
        if isinstance(obj_0, float):
            assert np.isclose(obj_0, obj_1, atol=1e-6)
        elif isinstance(obj_0, (str, int)):
            assert obj_0 == obj_1
        elif obj_0 is None:
            assert obj_1 is None
        elif isinstance(obj_0, dict):
            # Key order is part of the comparison, so compare ordered lists;
            # this also catches dicts with differing numbers of keys.
            assert list(obj_0.keys()) == list(obj_1.keys())
            for key, value_0 in obj_0.items():
                # NOTE(review): the value under 'missing' is deliberately not
                # compared; presumably it may legitimately differ between the
                # two models — confirm.
                if key != 'missing':
                    cls._recursive_compare(value_0, obj_1[key])
        else:
            # Without this check, extra trailing elements in obj_1 would be
            # ignored and missing ones would raise IndexError instead of
            # failing the assertion.
            assert len(obj_0) == len(obj_1)
            for item_0, item_1 in zip(obj_0, obj_1):
                cls._recursive_compare(item_0, item_1)

    def test_training_continuation(self) -> None:
        """64 rounds at once must match 32 rounds plus 32 continued rounds."""
        n_rows = 64
        n_cols = 32
        # Draw from the module's seeded generator (not the global NumPy
        # state) so the training data, and hence the trees, are reproducible.
        X = rng.randn(n_rows, n_cols)
        y = rng.randn(n_rows)
        dtrain = xgb.DMatrix(X, y)
        params = {'tree_method': 'gpu_hist', 'max_depth': '2',
                  'gamma': '0.1', 'alpha': '0.01'}

        # Reference model: all 64 rounds in one call.
        bst_0 = xgb.train(params, dtrain, num_boost_round=64)
        dump_0 = bst_0.get_dump(dump_format='json')

        # Same total number of rounds, split across a continuation.
        bst_1 = xgb.train(params, dtrain, num_boost_round=32)
        bst_1 = xgb.train(params, dtrain, num_boost_round=32, xgb_model=bst_1)
        dump_1 = bst_1.get_dump(dump_format='json')

        assert len(dump_0) == len(dump_1)
        for tree_0, tree_1 in zip(dump_0, dump_1):
            self._recursive_compare(json.loads(tree_0), json.loads(tree_1))