diff --git a/demo/dask/callbacks.py b/demo/dask/callbacks.py
new file mode 100644
index 000000000..e3bf5b39d
--- /dev/null
+++ b/demo/dask/callbacks.py
@@ -0,0 +1,87 @@
+"""Example of using callbacks with Dask"""
+import numpy as np
+import xgboost as xgb
+from xgboost.dask import DaskDMatrix
+from dask.distributed import Client
+from dask.distributed import LocalCluster
+from dask_ml.datasets import make_regression
+from dask_ml.model_selection import train_test_split
+
+
+def probability_for_going_backward(epoch):
+    return 0.999 / (1.0 + 0.05 * np.log(1.0 + epoch))
+
+
+# All callback functions must inherit from TrainingCallback
+class CustomEarlyStopping(xgb.callback.TrainingCallback):
+    """A custom early-stopping callback where stopping is determined stochastically.
+    In the beginning, allow the metric to become worse with a probability of 0.999.
+    As boosting progresses, the probability is adjusted downward."""
+
+    def __init__(self, *, validation_set, target_metric, maximize, seed):
+        self.validation_set = validation_set
+        self.target_metric = target_metric
+        self.maximize = maximize
+        self.seed = seed
+        self.rng = np.random.default_rng(seed=seed)
+        if maximize:
+            self.better = lambda x, y: x > y
+        else:
+            self.better = lambda x, y: x < y
+
+    def after_iteration(self, model, epoch, evals_log):
+        metric_history = evals_log[self.validation_set][self.target_metric]
+        if len(metric_history) < 2 or self.better(
+            metric_history[-1], metric_history[-2]
+        ):
+            return False  # continue training
+        p = probability_for_going_backward(epoch)
+        # Draw True (tolerate the worse metric) with probability p, False with 1 - p.
+        go_backward = self.rng.choice(2, size=(1,), replace=True, p=[1 - p, p]).astype(
+            bool
+        )[0]
+        print(
+            "The validation metric moved in the wrong direction. "
+            + f"Stopping training with probability {1 - p}..."
+        )
+        if go_backward:
+            return False  # continue training
+        else:
+            return True  # stop training
+
+
+def main(client):
+    m = 100000
+    n = 100
+    X, y = make_regression(n_samples=m, n_features=n, chunks=200, random_state=0)
+    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+
+    dtrain = DaskDMatrix(client, X_train, y_train)
+    dtest = DaskDMatrix(client, X_test, y_test)
+
+    output = xgb.dask.train(
+        client,
+        {
+            "verbosity": 1,
+            "tree_method": "hist",
+            "objective": "reg:squarederror",
+            "eval_metric": "rmse",
+            "max_depth": 6,
+            "learning_rate": 1.0,
+        },
+        dtrain,
+        num_boost_round=1000,
+        evals=[(dtrain, "train"), (dtest, "test")],
+        callbacks=[
+            CustomEarlyStopping(
+                validation_set="test", target_metric="rmse", maximize=False, seed=0
+            )
+        ],
+    )
+
+
+if __name__ == "__main__":
+    # or use other clusters for scaling
+    with LocalCluster(n_workers=4, threads_per_worker=1) as cluster:
+        with Client(cluster) as client:
+            main(client)
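
For quick local testing of the stochastic early-stopping logic without spinning up a Dask cluster, the same `CustomEarlyStopping` class also plugs into plain `xgb.train`. A minimal sketch, assuming the class above is in scope; the synthetic data and parameter values here are illustrative, not part of the demo:

```python
import numpy as np
import xgboost as xgb

# `CustomEarlyStopping` is assumed to be the class from the demo above,
# e.g. copied into this script or imported from it.

rng = np.random.default_rng(0)
X = rng.standard_normal((1000, 10))
y = X @ rng.standard_normal(10)

# Plain in-memory DMatrix objects instead of DaskDMatrix.
dtrain = xgb.DMatrix(X[:800], label=y[:800])
dtest = xgb.DMatrix(X[800:], label=y[800:])

booster = xgb.train(
    {"objective": "reg:squarederror", "eval_metric": "rmse", "learning_rate": 1.0},
    dtrain,
    num_boost_round=1000,
    evals=[(dtrain, "train"), (dtest, "test")],
    callbacks=[
        CustomEarlyStopping(
            validation_set="test", target_metric="rmse", maximize=False, seed=0
        )
    ],
)
```

Because the stopping decision is random, different `seed` values will stop at different boosting rounds; fixing `seed` keeps runs reproducible.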