Remove unnecessary dependencies in distributed test (#3132)
parent cf89fa7139
commit 11bfa8584d
@@ -1,20 +1,17 @@
 #!/usr/bin/python
-import numpy as np
-import scipy.sparse
-import pickle
 import xgboost as xgb

-# always call this before using distributed module
+# Always call this before using distributed module
 xgb.rabit.init()

 # Load file, file will be automatically sharded in distributed mode.
 dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
 dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')

-# specify parameters via map, definition are same as c++ version
+# Specify parameters via map, definition are same as c++ version
 param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }

-# specify validations set to watch performance
+# Specify validations set to watch performance
 watchlist = [(dtest,'eval'), (dtrain,'train')]
 num_round = 20

@@ -22,7 +19,7 @@ num_round = 20
 # Currently, this script only support calling train once for fault recovery purpose.
 bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)

-# save the model, only ask process 0 to save the model.
+# Save the model, only ask process 0 to save the model.
 if xgb.rabit.get_rank() == 0:
     bst.save_model("test.model")
     xgb.rabit.tracker_print("Finished training\n")
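For reference (not part of the commit): the distributed run is normally started through the dmlc tracker (for example dmlc-core's dmlc-submit with a worker count), which is what makes the automatic sharding and the xgb.rabit calls above meaningful; only rank 0 writes test.model. Below is a minimal, single-process sketch for sanity-checking that checkpoint after the job finishes, assuming it is run from the same directory so the relative data path and the test.model filename from the script above still resolve.

#!/usr/bin/python
# Sketch only -- load the model written by rank 0 and score the test split locally.
import xgboost as xgb

# Same test file the distributed script evaluates against.
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')

# "test.model" is the checkpoint saved by rank 0 in the script above;
# the path is assumed relative to the directory the job was launched from.
bst = xgb.Booster(model_file='test.model')

preds = bst.predict(dtest)
labels = dtest.get_label()
# Simple classification error at a 0.5 threshold for binary:logistic output.
err = sum(1 for p, y in zip(preds, labels) if (p > 0.5) != (y == 1.0)) / float(len(preds))
print('error=%f' % err)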