[py] split value histograms
This commit is contained in:
@@ -240,3 +240,22 @@ def test_sklearn_nfolds_cv():
|
||||
cv3 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, stratified=True, seed=seed)
|
||||
assert cv1.shape[0] == cv2.shape[0] and cv2.shape[0] == cv3.shape[0]
|
||||
assert cv2.iloc[-1, 0] == cv3.iloc[-1, 0]
|
||||
|
||||
|
||||
def test_split_value_histograms():
|
||||
digits_2class = load_digits(2)
|
||||
|
||||
X = digits_2class['data']
|
||||
y = digits_2class['target']
|
||||
|
||||
dm = xgb.DMatrix(X, label=y)
|
||||
params = {'max_depth': 6, 'eta': 0.01, 'silent': 1, 'objective': 'binary:logistic'}
|
||||
|
||||
gbdt = xgb.train(params, dm, num_boost_round=10)
|
||||
assert gbdt.get_split_value_histogram("not_there", as_pandas=True).shape[0] == 0
|
||||
assert gbdt.get_split_value_histogram("not_there", as_pandas=False).shape[0] == 0
|
||||
assert gbdt.get_split_value_histogram("f28", bins=0).shape[0] == 1
|
||||
assert gbdt.get_split_value_histogram("f28", bins=1).shape[0] == 1
|
||||
assert gbdt.get_split_value_histogram("f28", bins=2).shape[0] == 2
|
||||
assert gbdt.get_split_value_histogram("f28", bins=5).shape[0] == 2
|
||||
assert gbdt.get_split_value_histogram("f28", bins=None).shape[0] == 2
|
||||
|
||||
Reference in New Issue
Block a user