""" Example of training survival model with Dask on CPU =================================================== """ import os import dask.array as da import dask.dataframe as dd from dask.distributed import Client, LocalCluster from xgboost import dask as dxgb from xgboost.dask import DaskDMatrix def main(client: Client) -> da.Array: # Load an example survival data from CSV into a Dask data frame. # The Veterans' Administration Lung Cancer Trial # The Statistical Analysis of Failure Time Data by Kalbfleisch J. and Prentice R (1980) CURRENT_DIR = os.path.dirname(__file__) df = dd.read_csv( os.path.join(CURRENT_DIR, os.pardir, "data", "veterans_lung_cancer.csv") ) # DaskDMatrix acts like normal DMatrix, works as a proxy for local # DMatrix scatter around workers. # For AFT survival, you'd need to extract the lower and upper bounds for the label # and pass them as arguments to DaskDMatrix. y_lower_bound = df["Survival_label_lower_bound"] y_upper_bound = df["Survival_label_upper_bound"] X = df.drop(["Survival_label_lower_bound", "Survival_label_upper_bound"], axis=1) dtrain = DaskDMatrix( client, X, label_lower_bound=y_lower_bound, label_upper_bound=y_upper_bound ) # Use train method from xgboost.dask instead of xgboost. This # distributed version of train returns a dictionary containing the # resulting booster and evaluation history obtained from # evaluation metrics. params = { "verbosity": 1, "objective": "survival:aft", "eval_metric": "aft-nloglik", "learning_rate": 0.05, "aft_loss_distribution_scale": 1.20, "aft_loss_distribution": "normal", "max_depth": 6, "lambda": 0.01, "alpha": 0.02, } output = dxgb.train( client, params, dtrain, num_boost_round=100, evals=[(dtrain, "train")] ) bst = output["booster"] history = output["history"] # you can pass output directly into `predict` too. prediction = dxgb.predict(client, bst, dtrain) print("Evaluation history: ", history) # Uncomment the following line to save the model to the disk # bst.save_model('survival_model.json') return prediction if __name__ == "__main__": # or use other clusters for scaling with LocalCluster(n_workers=7, threads_per_worker=4) as cluster: with Client(cluster) as client: main(client)