From 303c603c7d7b77006303e71be5d5ef56c53f7152 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 18 Apr 2024 19:09:30 +0800 Subject: [PATCH] [pyspark] Reuse the collective communicator. (#10198) --- python-package/xgboost/spark/utils.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py index c0a876419..7dbe290ae 100644 --- a/python-package/xgboost/spark/utils.py +++ b/python-package/xgboost/spark/utils.py @@ -14,7 +14,8 @@ import pyspark from pyspark import BarrierTaskContext, SparkConf, SparkContext, SparkFiles, TaskContext from pyspark.sql.session import SparkSession -from xgboost import Booster, XGBModel, collective +from xgboost import Booster, XGBModel +from xgboost.collective import CommunicatorContext as CCtx from xgboost.tracker import RabitTracker @@ -42,22 +43,12 @@ def _get_default_params_from_func( return filtered_params_dict -class CommunicatorContext: - """A context controlling collective communicator initialization and finalization. - This isn't specificially necessary (note Part 3), but it is more understandable - coding-wise. - - """ +class CommunicatorContext(CCtx): # pylint: disable=too-few-public-methods + """Context with PySpark specific task ID.""" def __init__(self, context: BarrierTaskContext, **args: Any) -> None: - self.args = args - self.args["DMLC_TASK_ID"] = str(context.partitionId()) - - def __enter__(self) -> None: - collective.init(**self.args) - - def __exit__(self, *args: Any) -> None: - collective.finalize() + args["DMLC_TASK_ID"] = str(context.partitionId()) + super().__init__(**args) def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]: