diff --git a/tests/python/testing.py b/tests/python/testing.py index 6294964dc..4b2b31e09 100644 --- a/tests/python/testing.py +++ b/tests/python/testing.py @@ -272,6 +272,8 @@ def eval_error_metric(predt, dtrain: xgb.DMatrix): label = dtrain.get_label() r = np.zeros(predt.shape) gt = predt > 0.5 + if predt.size == 0: + return "CustomErr", 0 r[gt] = 1 - label[gt] le = predt <= 0.5 r[le] = label[le]