Remove unnecessary dependencies in distributed test (#3132)

This commit is contained in:
Yuan (Terry) Tang 2018-02-24 20:24:34 -05:00 committed by GitHub
parent cf89fa7139
commit 11bfa8584d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,20 +1,17 @@
#!/usr/bin/python #!/usr/bin/python
import numpy as np
import scipy.sparse
import pickle
import xgboost as xgb import xgboost as xgb
# always call this before using distributed module # Always call this before using distributed module
xgb.rabit.init() xgb.rabit.init()
# Load file, file will be automatically sharded in distributed mode. # Load file, file will be automatically sharded in distributed mode.
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train') dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test') dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
# specify parameters via map, definition are same as c++ version # Specify parameters via map, definition are same as c++ version
param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' } param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }
# specify validations set to watch performance # Specify validations set to watch performance
watchlist = [(dtest,'eval'), (dtrain,'train')] watchlist = [(dtest,'eval'), (dtrain,'train')]
num_round = 20 num_round = 20
@ -22,7 +19,7 @@ num_round = 20
# Currently, this script only support calling train once for fault recovery purpose. # Currently, this script only support calling train once for fault recovery purpose.
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2) bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
# save the model, only ask process 0 to save the model. # Save the model, only ask process 0 to save the model.
if xgb.rabit.get_rank() == 0: if xgb.rabit.get_rank() == 0:
bst.save_model("test.model") bst.save_model("test.model")
xgb.rabit.tracker_print("Finished training\n") xgb.rabit.tracker_print("Finished training\n")