Rewrite approx (#7214)

This PR rewrites the approx tree method to reuse the code base from hist, for better performance and code sharing.

The rewrite has many benefits:
- Support for both `max_leaves` and `max_depth`.
- Support for `grow_policy`.
- Support for monotonic constraints.
- Support for feature weights.
- Support for easier bin configuration (`max_bin`).
- Support for categorical data.
- Faster performance for most datasets (many times faster).
- Support for prediction cache.
- Significantly better performance for external memory.
- Unites the code base between approx and hist.
This commit is contained in:
Jiaming Yuan
2022-01-10 21:15:05 +08:00
committed by GitHub
parent ed95e77752
commit 001503186c
22 changed files with 635 additions and 264 deletions

View File

@@ -1031,10 +1031,10 @@ def test_pandas_input():
np.array([0, 1]))
def run_feature_weights(X, y, fw, model=xgb.XGBRegressor):
def run_feature_weights(X, y, fw, tree_method, model=xgb.XGBRegressor):
with tempfile.TemporaryDirectory() as tmpdir:
colsample_bynode = 0.5
reg = model(tree_method='hist', colsample_bynode=colsample_bynode)
reg = model(tree_method=tree_method, colsample_bynode=colsample_bynode)
reg.fit(X, y, feature_weights=fw)
model_path = os.path.join(tmpdir, 'model.json')
@@ -1069,7 +1069,8 @@ def run_feature_weights(X, y, fw, model=xgb.XGBRegressor):
return w
def test_feature_weights():
@pytest.mark.parametrize("tree_method", ["approx", "hist"])
def test_feature_weights(tree_method):
kRows = 512
kCols = 64
X = rng.randn(kRows, kCols)
@@ -1078,12 +1079,12 @@ def test_feature_weights():
fw = np.ones(shape=(kCols,))
for i in range(kCols):
fw[i] *= float(i)
poly_increasing = run_feature_weights(X, y, fw, xgb.XGBRegressor)
poly_increasing = run_feature_weights(X, y, fw, tree_method, xgb.XGBRegressor)
fw = np.ones(shape=(kCols,))
for i in range(kCols):
fw[i] *= float(kCols - i)
poly_decreasing = run_feature_weights(X, y, fw, xgb.XGBRegressor)
poly_decreasing = run_feature_weights(X, y, fw, tree_method, xgb.XGBRegressor)
# Approximate test; this depends on the implementation of the random
# number generator in the standard library.