[pyspark] Reuse the collective communicator. (#10198)
This commit is contained in:
parent
0aa2600399
commit
303c603c7d
@ -14,7 +14,8 @@ import pyspark
|
|||||||
from pyspark import BarrierTaskContext, SparkConf, SparkContext, SparkFiles, TaskContext
|
from pyspark import BarrierTaskContext, SparkConf, SparkContext, SparkFiles, TaskContext
|
||||||
from pyspark.sql.session import SparkSession
|
from pyspark.sql.session import SparkSession
|
||||||
|
|
||||||
from xgboost import Booster, XGBModel, collective
|
from xgboost import Booster, XGBModel
|
||||||
|
from xgboost.collective import CommunicatorContext as CCtx
|
||||||
from xgboost.tracker import RabitTracker
|
from xgboost.tracker import RabitTracker
|
||||||
|
|
||||||
|
|
||||||
@ -42,22 +43,12 @@ def _get_default_params_from_func(
|
|||||||
return filtered_params_dict
|
return filtered_params_dict
|
||||||
|
|
||||||
|
|
||||||
class CommunicatorContext:
|
class CommunicatorContext(CCtx): # pylint: disable=too-few-public-methods
|
||||||
"""A context controlling collective communicator initialization and finalization.
|
"""Context with PySpark specific task ID."""
|
||||||
This isn't specificially necessary (note Part 3), but it is more understandable
|
|
||||||
coding-wise.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, context: BarrierTaskContext, **args: Any) -> None:
|
def __init__(self, context: BarrierTaskContext, **args: Any) -> None:
|
||||||
self.args = args
|
args["DMLC_TASK_ID"] = str(context.partitionId())
|
||||||
self.args["DMLC_TASK_ID"] = str(context.partitionId())
|
super().__init__(**args)
|
||||||
|
|
||||||
def __enter__(self) -> None:
|
|
||||||
collective.init(**self.args)
|
|
||||||
|
|
||||||
def __exit__(self, *args: Any) -> None:
|
|
||||||
collective.finalize()
|
|
||||||
|
|
||||||
|
|
||||||
def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
|
def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user