"""
|
|
Example of using callbacks with Dask
|
|
====================================
|
|
"""
import numpy as np
from dask.distributed import Client, LocalCluster
from dask_ml.datasets import make_regression
from dask_ml.model_selection import train_test_split
from xgboost.dask import DaskDMatrix

import xgboost as xgb


def probability_for_going_backward(epoch):
    """Return the probability of tolerating a worse metric at the given epoch."""
    return 0.999 / (1.0 + 0.05 * np.log(1.0 + epoch))
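
# A quick feel for the decay (plain arithmetic on the formula above):
# p(0) = 0.999, p(10) ≈ 0.89, p(100) ≈ 0.81, p(1000) ≈ 0.74, so the callback
# becomes less forgiving only slowly as boosting progresses.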


# All callback classes must inherit from TrainingCallback.
class CustomEarlyStopping(xgb.callback.TrainingCallback):
    """A custom early-stopping callback where stopping is decided stochastically.

    At the beginning, the metric is allowed to become worse with probability
    0.999.  As boosting progresses, that probability is adjusted downward by
    ``probability_for_going_backward``.
    """

    def __init__(self, *, validation_set, target_metric, maximize, seed):
        self.validation_set = validation_set
        self.target_metric = target_metric
        self.maximize = maximize
        self.seed = seed
        self.rng = np.random.default_rng(seed=seed)
        # ``better(x, y)`` is True when x is an improvement over y.
        if maximize:
            self.better = lambda x, y: x > y
        else:
            self.better = lambda x, y: x < y

    def after_iteration(self, model, epoch, evals_log):
        metric_history = evals_log[self.validation_set][self.target_metric]
        if len(metric_history) < 2 or self.better(
            metric_history[-1], metric_history[-2]
        ):
            return False  # continue training
        # The metric got worse: tolerate it ("go backward") with probability p.
        p = probability_for_going_backward(epoch)
        go_backward = self.rng.choice(2, size=(1,), replace=True, p=[1 - p, p]).astype(
            bool  # np.bool was removed in NumPy 1.24; use the builtin instead
        )[0]
        print(
            "The validation metric went in the wrong direction. "
            + f"Stopping training with probability {1 - p}..."
        )
        if go_backward:
            return False  # continue training
        else:
            return True  # stop training
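
# Note on the callback API used above: ``after_iteration`` is called once per
# boosting round, and returning True stops training.  ``evals_log`` is nested
# first by the names given in ``evals`` and then by metric name, e.g.
# {"train": {"rmse": [...]}, "test": {"rmse": [...]}}.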


def main(client):
    m = 100000
    n = 100
    X, y = make_regression(n_samples=m, n_features=n, chunks=200, random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    dtrain = DaskDMatrix(client, X_train, y_train)
    dtest = DaskDMatrix(client, X_test, y_test)

    output = xgb.dask.train(
        client,
        {
            "verbosity": 1,
            "tree_method": "hist",
            "objective": "reg:squarederror",
            "eval_metric": "rmse",
            "max_depth": 6,
            "learning_rate": 1.0,
        },
        dtrain,
        num_boost_round=1000,
        evals=[(dtrain, "train"), (dtest, "test")],
        callbacks=[
            CustomEarlyStopping(
                validation_set="test", target_metric="rmse", maximize=False, seed=0
            )
        ],
    )
    # ``xgb.dask.train`` returns a dict holding the trained booster and the
    # evaluation history; show how far training actually got.
    print("Boosted rounds:", output["booster"].num_boosted_rounds())


if __name__ == "__main__":
    # Or use other clusters for scaling.
    with LocalCluster(n_workers=4, threads_per_worker=1) as cluster:
        with Client(cluster) as client:
            main(client)
|