merge latest, Jan 12 2024
This commit is contained in:
@@ -7,6 +7,7 @@ Demo for using and defining callback functions
|
||||
import argparse
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
@@ -17,24 +18,26 @@ import xgboost as xgb
|
||||
|
||||
|
||||
class Plotting(xgb.callback.TrainingCallback):
|
||||
"""Plot evaluation result during training. Only for demonstration purpose as it's quite
|
||||
slow to draw.
|
||||
"""Plot evaluation result during training. Only for demonstration purpose as it's
|
||||
quite slow to draw using matplotlib.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, rounds):
|
||||
def __init__(self, rounds: int) -> None:
|
||||
self.fig = plt.figure()
|
||||
self.ax = self.fig.add_subplot(111)
|
||||
self.rounds = rounds
|
||||
self.lines = {}
|
||||
self.lines: Dict[str, plt.Line2D] = {}
|
||||
self.fig.show()
|
||||
self.x = np.linspace(0, self.rounds, self.rounds)
|
||||
plt.ion()
|
||||
|
||||
def _get_key(self, data, metric):
|
||||
def _get_key(self, data: str, metric: str) -> str:
|
||||
return f"{data}-{metric}"
|
||||
|
||||
def after_iteration(self, model, epoch, evals_log):
|
||||
def after_iteration(
|
||||
self, model: xgb.Booster, epoch: int, evals_log: Dict[str, dict]
|
||||
) -> bool:
|
||||
"""Update the plot."""
|
||||
if not self.lines:
|
||||
for data, metric in evals_log.items():
|
||||
@@ -55,7 +58,7 @@ class Plotting(xgb.callback.TrainingCallback):
|
||||
return False
|
||||
|
||||
|
||||
def custom_callback():
|
||||
def custom_callback() -> None:
|
||||
"""Demo for defining a custom callback function that plots evaluation result during
|
||||
training."""
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
@@ -82,19 +85,27 @@ def custom_callback():
|
||||
)
|
||||
|
||||
|
||||
def check_point_callback():
|
||||
# only for demo, set a larger value (like 100) in practice as checkpointing is quite
|
||||
def check_point_callback() -> None:
|
||||
"""Demo for using the checkpoint callback. Custom logic for handling output is
|
||||
usually required and users are encouraged to define their own callback for
|
||||
checkpointing operations. The builtin one can be used as a starting point.
|
||||
|
||||
"""
|
||||
# Only for demo, set a larger value (like 100) in practice as checkpointing is quite
|
||||
# slow.
|
||||
rounds = 2
|
||||
|
||||
def check(as_pickle):
|
||||
def check(as_pickle: bool) -> None:
|
||||
for i in range(0, 10, rounds):
|
||||
if i == 0:
|
||||
continue
|
||||
if as_pickle:
|
||||
path = os.path.join(tmpdir, "model_" + str(i) + ".pkl")
|
||||
else:
|
||||
path = os.path.join(tmpdir, "model_" + str(i) + ".json")
|
||||
path = os.path.join(
|
||||
tmpdir,
|
||||
f"model_{i}.{xgb.callback.TrainingCheckPoint.default_format}",
|
||||
)
|
||||
assert os.path.exists(path)
|
||||
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
|
||||
@@ -78,6 +78,10 @@ def categorical_model(X: pd.DataFrame, y: pd.Series, output_dir: str) -> None:
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, random_state=1994, test_size=0.2
|
||||
)
|
||||
# Be aware that the encoding for X_train and X_test are the same here. In practice,
|
||||
# we should try to use an encoder like (sklearn OrdinalEncoder) to obtain the
|
||||
# categorical values.
|
||||
|
||||
# Specify `enable_categorical` to True.
|
||||
clf = xgb.XGBClassifier(
|
||||
**params,
|
||||
|
||||
@@ -58,7 +58,7 @@ def individual_tree() -> None:
|
||||
|
||||
|
||||
def model_slices() -> None:
|
||||
"""Inference with each individual using model slices."""
|
||||
"""Inference with each individual tree using model slices."""
|
||||
X_train, y_train = load_svmlight_file(train)
|
||||
X_test, y_test = load_svmlight_file(test)
|
||||
Xy_train = xgb.QuantileDMatrix(X_train, y_train)
|
||||
|
||||
@@ -9,7 +9,7 @@ https://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_qu
|
||||
|
||||
.. note::
|
||||
|
||||
The feature is only supported using the Python package. In addition, quantile
|
||||
The feature is only supported using the Python, R, and C packages. In addition, quantile
|
||||
crossing can happen due to limitation in the algorithm.
|
||||
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user