Fix loading old logit model, helper for converting old pickle. (#5281)

* Fix loading old logit model.
* Add a helper script for converting old pickle file.
* Add version as a model parameter.
* Remove the size check in R test to relax the size constraint.
* Add missing R doc to pass linting. Run devtools.
* Cleanup old model IO logic.
* Test compatibility on CI.
* Make the argument required.
This commit is contained in:
Jiaming Yuan
2020-02-13 15:28:13 +08:00
committed by GitHub
parent 5ca21f252a
commit 213f4fa45a
30 changed files with 403 additions and 137 deletions

View File

@@ -59,6 +59,29 @@ def generate_regression_model():
reg.save_model(skl_json('reg'))
def generate_logistic_model():
    '''Train 'binary:logistic' models and save them under the 'logit' name.

    A native booster and a scikit-learn ``XGBClassifier`` are both fitted on
    the module-level ``X`` / ``w`` with freshly sampled 0/1 labels.  Each is
    saved twice, to the paths returned by the ``*_bin`` and ``*_json``
    helpers.
    '''
    print('Logistic')
    labels = np.random.randint(0, 2, size=kRows)
    # Both classes have to be present in the sampled labels.
    assert labels.max() == 1 and labels.min() == 0

    train_params = {
        'tree_method': 'hist',
        'num_parallel_tree': kForests,
        'max_depth': kMaxDepth,
        'objective': 'binary:logistic',
    }
    dtrain = xgboost.DMatrix(X, label=labels, weight=w)
    booster = xgboost.train(train_params, num_boost_round=kRounds,
                            dtrain=dtrain)
    for make_path in (booster_bin, booster_json):
        booster.save_model(make_path('logit'))

    classifier = xgboost.XGBClassifier(
        tree_method='hist',
        num_parallel_tree=kForests,
        max_depth=kMaxDepth,
        n_estimators=kRounds,
    )
    classifier.fit(X, labels, w)
    for make_path in (skl_bin, skl_json):
        classifier.save_model(make_path('logit'))
def generate_classification_model():
print('Classification')
y = np.random.randint(0, kClasses, size=kRows)
@@ -83,7 +106,7 @@ def generate_classification_model():
def generate_ranking_model():
print('Learning to Rank')
y = np.random.randint(5, size=kRows)
w = np.random.randn(20)
w = np.random.uniform(size=20)
g = np.repeat(50, 20)
data = xgboost.DMatrix(X, y, weight=w)
@@ -119,6 +142,7 @@ if __name__ == '__main__':
os.mkdir(target_dir)
generate_regression_model()
generate_logistic_model()
generate_classification_model()
generate_ranking_model()
write_versions()

View File

@@ -1 +0,0 @@
{'numpy': '1.16.4', 'xgboost': '1.0.0-SNAPSHOT'}

File diff suppressed because one or more lines are too long

View File

@@ -1 +0,0 @@
{"learner":{"attributes":{},"gradient_booster":{"model":{"gbtree_model_param":{"num_trees":"4","size_leaf_vector":"0"},"tree_info":[0,0,0,0],"trees":[{"base_weights":[2.18596185597164094e-09,-3.76773595809936523e-01,4.55630868673324585e-02,1.12075649201869965e-01,-1.93485423922538757e-01],"default_left":[false,false,false,false,false],"id":0,"leaf_child_counts":[1,0,2,0,0],"left_children":[1,-1,3,-1,-1],"loss_changes":[4.20947641134262085e-01,0.00000000000000000e+00,3.69498044252395630e-01,5.97973287105560303e-01,6.13317489624023438e-01],"parents":[2147483647,0,0,2,2],"right_children":[2,-1,4,-1,-1],"split_conditions":[-1.45796775817871094e+00,-5.65160401165485382e-02,8.68250608444213867e-01,1.68113484978675842e-02,-2.90228147059679031e-02],"split_indices":[3,0,1,0,0],"sum_hessian":[2.25207920074462891e+01,1.64538443088531494e+00,2.08754062652587891e+01,1.67469234466552734e+01,4.12848377227783203e+00],"tree_param":{"num_feature":"4","num_nodes":"5","size_leaf_vector":"0"}},{"base_weights":[2.18596185597164094e-09,-3.76773595809936523e-01,4.55630868673324585e-02,1.12075649201869965e-01,-1.93485423922538757e-01],"default_left":[false,false,false,false,false],"id":1,"leaf_child_counts":[1,0,2,0,0],"left_children":[1,-1,3,-1,-1],"loss_changes":[4.20947641134262085e-01,0.00000000000000000e+00,3.69498044252395630e-01,5.97973287105560303e-01,6.13317489624023438e-01],"parents":[2147483647,0,0,2,2],"right_children":[2,-1,4,-1,-1],"split_conditions":[-1.45796775817871094e+00,-5.65160401165485382e-02,8.68250608444213867e-01,1.68113484978675842e-02,-2.90228147059679031e-02],"split_indices":[3,0,1,0,0],"sum_hessian":[2.25207920074462891e+01,1.64538443088531494e+00,2.08754062652587891e+01,1.67469234466552734e+01,4.12848377227783203e+00],"tree_param":{"num_feature":"4","num_nodes":"5","size_leaf_vector":"0"}},{"base_weights":[2.31542762740843955e-09,-1.12662151455879211e-01,3.53309124708175659e-01,-4.52967911958694458e-01,-4.28877249360084534e-02,-1.19008123874664307e-01,4.9823150
0387191772e-01],"default_left":[false,false,false,false,false,false,false],"id":2,"leaf_child_counts":[0,2,2,0,0,0,0],"left_children":[1,3,5,-1,-1,-1,-1],"loss_changes":[1.03438735008239746e+00,4.48428511619567871e-01,4.89362835884094238e-01,0.00000000000000000e+00,2.74164468050003052e-01,0.00000000000000000e+00,0.00000000000000000e+00],"parents":[2147483647,0,0,1,1,2,2],"right_children":[2,4,6,-1,-1,-1,-1],"split_conditions":[5.69312453269958496e-01,-1.49666213989257812e+00,-3.32068562507629395e-01,-6.79451897740364075e-02,-6.43315911293029785e-03,-1.78512185811996460e-02,7.47347250580787659e-02],"split_indices":[1,1,0,0,0,0,0],"sum_hessian":[2.39866485595703125e+01,1.87036170959472656e+01,5.28303003311157227e+00,2.24795222282409668e+00,1.64556655883789062e+01,1.28239238262176514e+00,4.00063753128051758e+00],"tree_param":{"num_feature":"4","num_nodes":"7","size_leaf_vector":"0"}},{"base_weights":[2.31542762740843955e-09,-1.12662151455879211e-01,3.53309124708175659e-01,-4.52967911958694458e-01,-4.28877249360084534e-02,-1.19008123874664307e-01,4.98231500387191772e-01],"default_left":[false,false,false,false,false,false,false],"id":3,"leaf_child_counts":[0,2,2,0,0,0,0],"left_children":[1,3,5,-1,-1,-1,-1],"loss_changes":[1.03438735008239746e+00,4.48428511619567871e-01,4.89362835884094238e-01,0.00000000000000000e+00,2.74164468050003052e-01,0.00000000000000000e+00,0.00000000000000000e+00],"parents":[2147483647,0,0,1,1,2,2],"right_children":[2,4,6,-1,-1,-1,-1],"split_conditions":[5.69312453269958496e-01,-1.49666213989257812e+00,-3.32068562507629395e-01,-6.79451897740364075e-02,-6.43315911293029785e-03,-1.78512185811996460e-02,7.47347250580787659e-02],"split_indices":[1,1,0,0,0,0,0],"sum_hessian":[2.39866485595703125e+01,1.87036170959472656e+01,5.28303003311157227e+00,2.24795222282409668e+00,1.64556655883789062e+01,1.28239238262176514e+00,4.00063753128051758e+00],"tree_param":{"num_feature":"4","num_nodes":"7","size_leaf_vector":"0"}}]},"name":"gbtree"},"learner_model_para
m":{"base_score":"0.500000","num_class":"0","num_feature":"4"},"objective":{"lambda_rank_param":{"fix_list_weight":"0","num_pairsample":"1"},"name":"rank:ndcg"}},"version":[1,0,0]}

View File

@@ -1 +0,0 @@
{"learner":{"attributes":{},"gradient_booster":{"model":{"gbtree_model_param":{"num_trees":"4","size_leaf_vector":"0"},"tree_info":[0,0,0,0],"trees":[{"base_weights":[-5.37645816802978516e-01,-4.36891138553619385e-01,-6.70873284339904785e-01,-1.25496864318847656e+00,-4.07270163297653198e-01,-6.88224375247955322e-01,4.64901357889175415e-01],"default_left":[false,false,false,false,false,false,false],"id":0,"leaf_child_counts":[0,2,2,0,0,0,0],"left_children":[1,3,5,-1,-1,-1,-1],"loss_changes":[6.49523925781250000e+00,6.53602600097656250e+00,4.57461547851562500e+00,2.30323791503906250e-01,6.39891815185546875e+00,4.40366363525390625e+00,2.28362298011779785e+00],"parents":[2147483647,0,0,1,1,2,2],"right_children":[2,4,6,-1,-1,-1,-1],"split_conditions":[1.89942225813865662e-01,-1.81951093673706055e+00,2.12066125869750977e+00,-1.88245311379432678e-01,-6.10905252397060394e-02,-1.03233657777309418e-01,6.97352066636085510e-02],"split_indices":[1,0,0,0,0,0,0],"sum_hessian":[5.04713470458984375e+02,2.89816162109375000e+02,2.14897293090820312e+02,8.68150043487548828e+00,2.81134674072265625e+02,2.12051849365234375e+02,2.84543561935424805e+00],"tree_param":{"num_feature":"4","num_nodes":"7","size_leaf_vector":"0"}},{"base_weights":[-5.37645816802978516e-01,-4.36891138553619385e-01,-6.70873284339904785e-01,-1.25496864318847656e+00,-4.07270163297653198e-01,-6.88224375247955322e-01,4.64901357889175415e-01],"default_left":[false,false,false,false,false,false,false],"id":1,"leaf_child_counts":[0,2,2,0,0,0,0],"left_children":[1,3,5,-1,-1,-1,-1],"loss_changes":[6.49523925781250000e+00,6.53602600097656250e+00,4.57461547851562500e+00,2.30323791503906250e-01,6.39891815185546875e+00,4.40366363525390625e+00,2.28362298011779785e+00],"parents":[2147483647,0,0,1,1,2,2],"right_children":[2,4,6,-1,-1,-1,-1],"split_conditions":[1.89942225813865662e-01,-1.81951093673706055e+00,2.12066125869750977e+00,-1.88245311379432678e-01,-6.10905252397060394e-02,-1.03233657777309418e-01,6.97352066636085510e-02],"
split_indices":[1,0,0,0,0,0,0],"sum_hessian":[5.04713470458984375e+02,2.89816162109375000e+02,2.14897293090820312e+02,8.68150043487548828e+00,2.81134674072265625e+02,2.12051849365234375e+02,2.84543561935424805e+00],"tree_param":{"num_feature":"4","num_nodes":"7","size_leaf_vector":"0"}},{"base_weights":[-3.77470612525939941e-01,3.31088960170745850e-01,-3.92237067222595215e-01,8.17872881889343262e-01,1.18046358227729797e-01,-3.00728023052215576e-01,-4.70518797636032104e-01],"default_left":[false,false,false,false,false,false,false],"id":2,"leaf_child_counts":[0,2,2,0,0,0,0],"left_children":[1,3,5,-1,-1,-1,-1],"loss_changes":[5.42109680175781250e+00,1.03034389019012451e+00,3.41049194335937500e+00,0.00000000000000000e+00,1.19803142547607422e+00,4.23731803894042969e+00,4.69757843017578125e+00],"parents":[2147483647,0,0,1,1,2,2],"right_children":[2,4,6,-1,-1,-1,-1],"split_conditions":[-2.07929229736328125e+00,-5.09094715118408203e-01,-8.72411578893661499e-02,1.22680939733982086e-01,1.77069548517465591e-02,-4.51092049479484558e-02,-7.05778226256370544e-02],"split_indices":[3,0,3,0,0,0,0],"sum_hessian":[5.04713470458984375e+02,9.86623668670654297e+00,4.94847229003906250e+02,2.13924217224121094e+00,7.72699451446533203e+00,2.30380615234375000e+02,2.64466613769531250e+02],"tree_param":{"num_feature":"4","num_nodes":"7","size_leaf_vector":"0"}},{"base_weights":[-3.77470612525939941e-01,3.31088960170745850e-01,-3.92237067222595215e-01,8.17872881889343262e-01,1.18046358227729797e-01,-3.00728023052215576e-01,-4.70518797636032104e-01],"default_left":[false,false,false,false,false,false,false],"id":3,"leaf_child_counts":[0,2,2,0,0,0,0],"left_children":[1,3,5,-1,-1,-1,-1],"loss_changes":[5.42109680175781250e+00,1.03034389019012451e+00,3.41049194335937500e+00,0.00000000000000000e+00,1.19803142547607422e+00,4.23731803894042969e+00,4.69757843017578125e+00],"parents":[2147483647,0,0,1,1,2,2],"right_children":[2,4,6,-1,-1,-1,-1],"split_conditions":[-2.07929229736328125e+00,-5.09094715118
408203e-01,-8.72411578893661499e-02,1.22680939733982086e-01,1.77069548517465591e-02,-4.51092049479484558e-02,-7.05778226256370544e-02],"split_indices":[3,0,3,0,0,0,0],"sum_hessian":[5.04713470458984375e+02,9.86623668670654297e+00,4.94847229003906250e+02,2.13924217224121094e+00,7.72699451446533203e+00,2.30380615234375000e+02,2.64466613769531250e+02],"tree_param":{"num_feature":"4","num_nodes":"7","size_leaf_vector":"0"}}]},"name":"gbtree"},"learner_model_param":{"base_score":"0.500000","num_class":"0","num_feature":"4"},"objective":{"name":"reg:squarederror","reg_loss_param":{"scale_pos_weight":"1"}}},"version":[1,0,0]}

File diff suppressed because one or more lines are too long

View File

@@ -1 +0,0 @@
{"learner":{"attributes":{"scikit_learn":"{\"n_estimators\": 2, \"objective\": \"rank:ndcg\", \"max_depth\": 2, \"learning_rate\": null, \"verbosity\": null, \"booster\": null, \"tree_method\": \"hist\", \"gamma\": null, \"min_child_weight\": null, \"max_delta_step\": null, \"subsample\": null, \"colsample_bytree\": null, \"colsample_bylevel\": null, \"colsample_bynode\": null, \"reg_alpha\": null, \"reg_lambda\": null, \"scale_pos_weight\": null, \"base_score\": null, \"missing\": NaN, \"num_parallel_tree\": 2, \"kwargs\": {}, \"random_state\": null, \"n_jobs\": null, \"monotone_constraints\": null, \"interaction_constraints\": null, \"importance_type\": \"gain\", \"gpu_id\": null, \"type\": \"XGBRanker\"}"},"gradient_booster":{"model":{"gbtree_model_param":{"num_trees":"4","size_leaf_vector":"0"},"tree_info":[0,0,0,0],"trees":[{"base_weights":[2.18596185597164094e-09,-3.76773595809936523e-01,4.55630868673324585e-02,1.12075649201869965e-01,-1.93485423922538757e-01],"default_left":[false,false,false,false,false],"id":0,"leaf_child_counts":[1,0,2,0,0],"left_children":[1,-1,3,-1,-1],"loss_changes":[4.20947641134262085e-01,0.00000000000000000e+00,3.69498044252395630e-01,5.97973287105560303e-01,6.13317489624023438e-01],"parents":[2147483647,0,0,2,2],"right_children":[2,-1,4,-1,-1],"split_conditions":[-1.45796775817871094e+00,-5.65160401165485382e-02,8.68250608444213867e-01,1.68113484978675842e-02,-2.90228147059679031e-02],"split_indices":[3,0,1,0,0],"sum_hessian":[2.25207920074462891e+01,1.64538443088531494e+00,2.08754062652587891e+01,1.67469234466552734e+01,4.12848377227783203e+00],"tree_param":{"num_feature":"4","num_nodes":"5","size_leaf_vector":"0"}},{"base_weights":[2.18596185597164094e-09,-3.76773595809936523e-01,4.55630868673324585e-02,1.12075649201869965e-01,-1.93485423922538757e-01],"default_left":[false,false,false,false,false],"id":1,"leaf_child_counts":[1,0,2,0,0],"left_children":[1,-1,3,-1,-1],"loss_changes":[4.20947641134262085e-01,0.00000000000000000e+00,
3.69498044252395630e-01,5.97973287105560303e-01,6.13317489624023438e-01],"parents":[2147483647,0,0,2,2],"right_children":[2,-1,4,-1,-1],"split_conditions":[-1.45796775817871094e+00,-5.65160401165485382e-02,8.68250608444213867e-01,1.68113484978675842e-02,-2.90228147059679031e-02],"split_indices":[3,0,1,0,0],"sum_hessian":[2.25207920074462891e+01,1.64538443088531494e+00,2.08754062652587891e+01,1.67469234466552734e+01,4.12848377227783203e+00],"tree_param":{"num_feature":"4","num_nodes":"5","size_leaf_vector":"0"}},{"base_weights":[2.31542762740843955e-09,-1.12662151455879211e-01,3.53309124708175659e-01,-4.52967911958694458e-01,-4.28877249360084534e-02,-1.19008123874664307e-01,4.98231500387191772e-01],"default_left":[false,false,false,false,false,false,false],"id":2,"leaf_child_counts":[0,2,2,0,0,0,0],"left_children":[1,3,5,-1,-1,-1,-1],"loss_changes":[1.03438735008239746e+00,4.48428511619567871e-01,4.89362835884094238e-01,0.00000000000000000e+00,2.74164468050003052e-01,0.00000000000000000e+00,0.00000000000000000e+00],"parents":[2147483647,0,0,1,1,2,2],"right_children":[2,4,6,-1,-1,-1,-1],"split_conditions":[5.69312453269958496e-01,-1.49666213989257812e+00,-3.32068562507629395e-01,-6.79451897740364075e-02,-6.43315911293029785e-03,-1.78512185811996460e-02,7.47347250580787659e-02],"split_indices":[1,1,0,0,0,0,0],"sum_hessian":[2.39866485595703125e+01,1.87036170959472656e+01,5.28303003311157227e+00,2.24795222282409668e+00,1.64556655883789062e+01,1.28239238262176514e+00,4.00063753128051758e+00],"tree_param":{"num_feature":"4","num_nodes":"7","size_leaf_vector":"0"}},{"base_weights":[2.31542762740843955e-09,-1.12662151455879211e-01,3.53309124708175659e-01,-4.52967911958694458e-01,-4.28877249360084534e-02,-1.19008123874664307e-01,4.98231500387191772e-01],"default_left":[false,false,false,false,false,false,false],"id":3,"leaf_child_counts":[0,2,2,0,0,0,0],"left_children":[1,3,5,-1,-1,-1,-1],"loss_changes":[1.03438735008239746e+00,4.48428511619567871e-01,4.89362835884094238e-01
,0.00000000000000000e+00,2.74164468050003052e-01,0.00000000000000000e+00,0.00000000000000000e+00],"parents":[2147483647,0,0,1,1,2,2],"right_children":[2,4,6,-1,-1,-1,-1],"split_conditions":[5.69312453269958496e-01,-1.49666213989257812e+00,-3.32068562507629395e-01,-6.79451897740364075e-02,-6.43315911293029785e-03,-1.78512185811996460e-02,7.47347250580787659e-02],"split_indices":[1,1,0,0,0,0,0],"sum_hessian":[2.39866485595703125e+01,1.87036170959472656e+01,5.28303003311157227e+00,2.24795222282409668e+00,1.64556655883789062e+01,1.28239238262176514e+00,4.00063753128051758e+00],"tree_param":{"num_feature":"4","num_nodes":"7","size_leaf_vector":"0"}}]},"name":"gbtree"},"learner_model_param":{"base_score":"0.500000","num_class":"0","num_feature":"4"},"objective":{"lambda_rank_param":{"fix_list_weight":"0","num_pairsample":"1"},"name":"rank:ndcg"}},"version":[1,0,0]}

File diff suppressed because one or more lines are too long

View File

@@ -39,7 +39,7 @@ class TestBasic(unittest.TestCase):
def test_basic(self):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
param = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
param = {'max_depth': 2, 'eta': 1,
'objective': 'binary:logistic'}
# specify validations set to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')]

View File

@@ -284,16 +284,31 @@ class TestModels(unittest.TestCase):
self.assertRaises(ValueError, bst.predict, dm1)
bst.predict(dm2) # success
def test_model_binary_io(self):
    '''Round-trip a booster through a binary model file.

    Trains a small model with ``scale_pos_weight`` set to '0.5', saves it,
    reloads it from disk and asserts that the value is still present in
    the reloaded booster's configuration.
    '''
    model_path = 'test_model_binary_io.bin'
    parameters = {'tree_method': 'hist', 'booster': 'gbtree',
                  'scale_pos_weight': '0.5'}
    X = np.random.random((10, 3))
    y = np.random.random((10,))
    dtrain = xgb.DMatrix(X, y)
    bst = xgb.train(parameters, dtrain, num_boost_round=2)
    try:
        bst.save_model(model_path)
        bst = xgb.Booster(model_file=model_path)
    finally:
        # Clean up even when saving or loading raises, so a failing run
        # does not leave the model file behind in the working directory.
        if os.path.exists(model_path):
            os.remove(model_path)

    config = json.loads(bst.save_config())
    assert float(config['learner']['objective'][
        'reg_loss_param']['scale_pos_weight']) == 0.5
def test_model_json_io(self):
model_path = './model.json'
model_path = 'test_model_json_io.json'
parameters = {'tree_method': 'hist', 'booster': 'gbtree'}
j_model = json_model(model_path, parameters)
assert isinstance(j_model['learner'], dict)
bst = xgb.Booster(model_file='./model.json')
bst = xgb.Booster(model_file=model_path)
bst.save_model(fname=model_path)
with open('./model.json', 'r') as fd:
with open(model_path, 'r') as fd:
j_model = json.load(fd)
assert isinstance(j_model['learner'], dict)
@@ -302,7 +317,7 @@ class TestModels(unittest.TestCase):
@pytest.mark.skipif(**tm.no_json_schema())
def test_json_schema(self):
import jsonschema
model_path = './model.json'
model_path = 'test_json_schema.json'
path = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
doc = os.path.join(path, 'doc', 'model.schema')

View File

@@ -1,47 +1,130 @@
import xgboost
import os
import generate_models as gm
import json
import zipfile
import pytest
def test_model_compability():
def run_model_param_check(config):
    '''Assert the learner parameters common to every checked model.

    ``config`` is the parsed configuration dictionary; it must report
    ``num_feature`` as the string '4' and ``booster`` as 'gbtree'.
    '''
    learner = config['learner']
    assert learner['learner_model_param']['num_feature'] == str(4)
    assert learner['learner_train_param']['booster'] == 'gbtree'
def run_booster_check(booster, name):
    '''Validate a loaded native booster against what its file name implies.

    Parameters
    ----------
    booster :
        Object providing ``save_config()`` (returning a JSON string) and
        ``get_dump()``.
    name : str
        Model file name.  The first of the substrings 'cls', 'logit' and
        'ltr' found in it selects the assertions to run; any other name
        must contain 'reg'.
    '''
    config = json.loads(booster.save_config())
    run_model_param_check(config)

    # Both sub-dictionaries were already accessed by run_model_param_check.
    model_param = config['learner']['learner_model_param']
    train_param = config['learner']['learner_train_param']

    if 'cls' in name:
        assert (len(booster.get_dump()) == gm.kForests * gm.kRounds *
                gm.kClasses)
        assert float(model_param['base_score']) == 0.5
        assert train_param['objective'] == 'multi:softmax'
    elif 'logit' in name:
        assert len(booster.get_dump()) == gm.kForests * gm.kRounds
        assert model_param['num_class'] == str(0)
        assert train_param['objective'] == 'binary:logistic'
    elif 'ltr' in name:
        assert train_param['objective'] == 'rank:ndcg'
    else:
        # Anything that is not cls/logit/ltr has to be a regression model.
        assert 'reg' in name
        assert len(booster.get_dump()) == gm.kForests * gm.kRounds
        assert float(model_param['base_score']) == 0.5
        assert train_param['objective'] == 'reg:squarederror'
def run_scikit_model_check(name, path):
    '''Load a scikit-learn wrapper model from ``path`` and validate it.

    Parameters
    ----------
    name : str
        Model file name.  The first of the substrings 'reg', 'cls', 'ltr'
        and 'logit' found in it selects the estimator class and the
        assertions; a name containing '0.90' gets the relaxed checks
        below.
    path : str
        Location of the model file passed to ``load_model``.

    Raises
    ------
    AssertionError
        When a check fails or ``name`` matches none of the known kinds.
    '''
    forest_size = gm.kRounds * gm.kForests
    is_090 = '0.90' in name

    def booster_config(estimator):
        # Parsed configuration of the estimator's underlying booster.
        return json.loads(estimator.get_booster().save_config())

    if 'reg' in name:
        reg = xgboost.XGBRegressor()
        reg.load_model(path)
        config = booster_config(reg)
        # Files tagged 0.90 are expected to report 'reg:linear'.
        expected_objective = 'reg:linear' if is_090 else 'reg:squarederror'
        assert config['learner']['learner_train_param'][
            'objective'] == expected_objective
        assert len(reg.get_booster().get_dump()) == forest_size
        run_model_param_check(config)
    elif 'cls' in name:
        clf = xgboost.XGBClassifier()
        clf.load_model(path)
        if not is_090:
            # Class bookkeeping is only checked for files not tagged 0.90.
            assert len(clf.classes_) == gm.kClasses
            assert len(clf._le.classes_) == gm.kClasses
            assert clf.n_classes_ == gm.kClasses
        assert (len(clf.get_booster().get_dump()) ==
                forest_size * gm.kClasses), path
        config = booster_config(clf)
        assert config['learner']['learner_train_param'][
            'objective'] == 'multi:softprob', path
        run_model_param_check(config)
    elif 'ltr' in name:
        ltr = xgboost.XGBRanker()
        ltr.load_model(path)
        assert len(ltr.get_booster().get_dump()) == forest_size
        config = booster_config(ltr)
        assert config['learner']['learner_train_param'][
            'objective'] == 'rank:ndcg'
        run_model_param_check(config)
    elif 'logit' in name:
        logit = xgboost.XGBClassifier()
        logit.load_model(path)
        assert len(logit.get_booster().get_dump()) == forest_size
        config = booster_config(logit)
        assert config['learner']['learner_train_param'][
            'objective'] == 'binary:logistic'
        # NOTE(review): unlike the other branches, run_model_param_check is
        # not called here -- confirm whether that is intentional.
    else:
        assert False
@pytest.mark.ci
def test_model_compatibility():
'''Test model compatibility, can only be run on CI as others don't
have the credentials.
'''
path = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(path, 'models')
try:
import boto3
import botocore
except ImportError:
pytest.skip(
'Skiping compatibility tests as boto3 is not installed.')
try:
s3_bucket = boto3.resource('s3').Bucket('xgboost-ci-jenkins-artifacts')
zip_path = 'xgboost_model_compatibility_test.zip'
s3_bucket.download_file(zip_path, zip_path)
except botocore.exceptions.NoCredentialsError:
pytest.skip(
'Skiping compatibility tests as running on non-CI environment.')
with zipfile.ZipFile(zip_path, 'r') as z:
z.extractall(path)
models = [
os.path.join(root, f) for root, subdir, files in os.walk(path)
for f in files
if f != 'version'
]
assert len(models) == 12
assert models
for path in models:
name = os.path.basename(path)
if name.startswith('xgboost-'):
booster = xgboost.Booster(model_file=path)
if name.find('cls') != -1:
assert (len(booster.get_dump()) ==
gm.kForests * gm.kRounds * gm.kClasses)
else:
assert len(booster.get_dump()) == gm.kForests * gm.kRounds
run_booster_check(booster, name)
elif name.startswith('xgboost_scikit'):
if name.find('reg') != -1:
reg = xgboost.XGBRegressor()
reg.load_model(path)
assert (len(reg.get_booster().get_dump()) ==
gm.kRounds * gm.kForests)
elif name.find('cls') != -1:
cls = xgboost.XGBClassifier()
cls.load_model(path)
assert len(cls.classes_) == gm.kClasses
assert len(cls._le.classes_) == gm.kClasses
assert cls.n_classes_ == gm.kClasses
assert (len(cls.get_booster().get_dump()) ==
gm.kRounds * gm.kForests * gm.kClasses), path
elif name.find('ltr') != -1:
ltr = xgboost.XGBRanker()
ltr.load_model(path)
assert (len(ltr.get_booster().get_dump()) ==
gm.kRounds * gm.kForests)
else:
assert False
run_scikit_model_check(name, path)
else:
assert False