[py] split value histograms

This commit is contained in:
Faron
2016-03-20 15:08:15 +01:00
parent 6691d5c3f4
commit cf607e2448
2 changed files with 63 additions and 1 deletions

View File

@@ -240,3 +240,22 @@ def test_sklearn_nfolds_cv():
cv3 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, stratified=True, seed=seed)
assert cv1.shape[0] == cv2.shape[0] and cv2.shape[0] == cv3.shape[0]
assert cv2.iloc[-1, 0] == cv3.iloc[-1, 0]
def test_split_value_histograms():
digits_2class = load_digits(2)
X = digits_2class['data']
y = digits_2class['target']
dm = xgb.DMatrix(X, label=y)
params = {'max_depth': 6, 'eta': 0.01, 'silent': 1, 'objective': 'binary:logistic'}
gbdt = xgb.train(params, dm, num_boost_round=10)
assert gbdt.get_split_value_histogram("not_there", as_pandas=True).shape[0] == 0
assert gbdt.get_split_value_histogram("not_there", as_pandas=False).shape[0] == 0
assert gbdt.get_split_value_histogram("f28", bins=0).shape[0] == 1
assert gbdt.get_split_value_histogram("f28", bins=1).shape[0] == 1
assert gbdt.get_split_value_histogram("f28", bins=2).shape[0] == 2
assert gbdt.get_split_value_histogram("f28", bins=5).shape[0] == 2
assert gbdt.get_split_value_histogram("f28", bins=None).shape[0] == 2