[pyspark] Reuse the collective communicator. (#10198)
This commit is contained in:
parent
0aa2600399
commit
303c603c7d
@ -14,7 +14,8 @@ import pyspark
|
||||
from pyspark import BarrierTaskContext, SparkConf, SparkContext, SparkFiles, TaskContext
|
||||
from pyspark.sql.session import SparkSession
|
||||
|
||||
from xgboost import Booster, XGBModel, collective
|
||||
from xgboost import Booster, XGBModel
|
||||
from xgboost.collective import CommunicatorContext as CCtx
|
||||
from xgboost.tracker import RabitTracker
|
||||
|
||||
|
||||
@ -42,22 +43,12 @@ def _get_default_params_from_func(
|
||||
return filtered_params_dict
|
||||
|
||||
|
||||
class CommunicatorContext:
|
||||
"""A context controlling collective communicator initialization and finalization.
|
||||
This isn't specifically necessary (note Part 3), but it is more understandable
|
||||
coding-wise.
|
||||
|
||||
"""
|
||||
class CommunicatorContext(CCtx): # pylint: disable=too-few-public-methods
|
||||
"""Context with PySpark specific task ID."""
|
||||
|
||||
def __init__(self, context: BarrierTaskContext, **args: Any) -> None:
|
||||
self.args = args
|
||||
self.args["DMLC_TASK_ID"] = str(context.partitionId())
|
||||
|
||||
def __enter__(self) -> None:
|
||||
collective.init(**self.args)
|
||||
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
collective.finalize()
|
||||
args["DMLC_TASK_ID"] = str(context.partitionId())
|
||||
super().__init__(**args)
|
||||
|
||||
|
||||
def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user