[pyspark] Add type hint to basic utilities. (#8375)

This commit is contained in:
Jiaming Yuan 2022-10-25 17:26:25 +08:00 committed by GitHub
parent 1d2f6de573
commit d0b99bdd95
4 changed files with 54 additions and 58 deletions

View File

@@ -1,4 +1,3 @@
-# type: ignore
 """PySpark XGBoost integration interface
 """

View File

@@ -1,11 +1,12 @@
-# type: ignore
 """Xgboost pyspark integration submodule for estimator API."""
 # pylint: disable=too-many-ancestors
+from typing import Any, Type
 from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol
 from xgboost import XGBClassifier, XGBRanker, XGBRegressor
-from .core import (
+from .core import (  # type: ignore
     SparkXGBClassifierModel,
     SparkXGBRankerModel,
     SparkXGBRegressorModel,
@@ -95,19 +96,19 @@ class SparkXGBRegressor(_SparkXGBEstimator):
     """
-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any) -> None:
         super().__init__()
         self.setParams(**kwargs)
     @classmethod
-    def _xgb_cls(cls):
+    def _xgb_cls(cls) -> Type[XGBRegressor]:
         return XGBRegressor
     @classmethod
-    def _pyspark_model_cls(cls):
+    def _pyspark_model_cls(cls) -> Type[SparkXGBRegressorModel]:
         return SparkXGBRegressorModel
-    def _validate_params(self):
+    def _validate_params(self) -> None:
         super()._validate_params()
         if self.isDefined(self.qid_col):
             raise ValueError(
@@ -209,7 +210,7 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPredictionCol):
     """
-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any) -> None:
         super().__init__()
         # The default 'objective' param value comes from sklearn `XGBClassifier` ctor,
         # but in pyspark we will automatically set objective param depending on
@@ -219,14 +220,14 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPredictionCol):
         self.setParams(**kwargs)
     @classmethod
-    def _xgb_cls(cls):
+    def _xgb_cls(cls) -> Type[XGBClassifier]:
         return XGBClassifier
     @classmethod
-    def _pyspark_model_cls(cls):
+    def _pyspark_model_cls(cls) -> Type[SparkXGBClassifierModel]:
         return SparkXGBClassifierModel
-    def _validate_params(self):
+    def _validate_params(self) -> None:
         super()._validate_params()
         if self.isDefined(self.qid_col):
             raise ValueError(
@@ -342,19 +343,19 @@ class SparkXGBRanker(_SparkXGBEstimator):
     >>> model.transform(df_test).show()
     """
-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any) -> None:
         super().__init__()
         self.setParams(**kwargs)
     @classmethod
-    def _xgb_cls(cls):
+    def _xgb_cls(cls) -> Type[XGBRanker]:
         return XGBRanker
     @classmethod
-    def _pyspark_model_cls(cls):
+    def _pyspark_model_cls(cls) -> Type[SparkXGBRankerModel]:
         return SparkXGBRankerModel
-    def _validate_params(self):
+    def _validate_params(self) -> None:
         super()._validate_params()
         if not self.isDefined(self.qid_col):
             raise ValueError(
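The three estimators now declare their wrapped sklearn and Spark-model classes through typed classmethods. A minimal sketch (not part of this commit; assumes xgboost and pyspark are installed) of what the Type[...] return annotations let a static checker verify:

from typing import Type

from xgboost import XGBRegressor
from xgboost.spark import SparkXGBRegressor

# `_xgb_cls` is a private hook of the estimator; it is called here only to
# illustrate the `Type[XGBRegressor]` annotation added in this file.
regressor_cls: Type[XGBRegressor] = SparkXGBRegressor._xgb_cls()
assert regressor_cls is XGBRegressor  # mypy and the runtime now agree on the wrapped class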

View File

@@ -1,4 +1,3 @@
-# type: ignore
 """Xgboost pyspark integration submodule for params."""
 # pylint: disable=too-few-public-methods
 from pyspark.ml.param import TypeConverters
@@ -12,7 +11,7 @@ class HasArbitraryParamsDict(Params):
     input.
     """
-    arbitrary_params_dict = Param(
+    arbitrary_params_dict: Param[dict] = Param(
         Params._dummy(),
         "arbitrary_params_dict",
         "arbitrary_params_dict This parameter holds all of the additional parameters which are "
@@ -31,6 +30,7 @@ class HasBaseMarginCol(Params):
         Params._dummy(),
         "base_margin_col",
         "This stores the name for the column of the base margin",
+        typeConverter=TypeConverters.toString,
     )
@@ -47,7 +47,7 @@ class HasFeaturesCols(Params):
         typeConverter=TypeConverters.toListString,
     )
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self._setDefault(features_cols=[])
@@ -69,7 +69,7 @@ class HasEnableSparseDataOptim(Params):
         typeConverter=TypeConverters.toBoolean,
     )
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self._setDefault(enable_sparse_data_optim=False)
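For context, the annotated-Param pattern introduced above (e.g. Param[dict]) generalizes to any mixin. A hedged sketch with a hypothetical column parameter, assuming pyspark >= 3.0 where Param is generic:

from pyspark.ml.param import Param, Params, TypeConverters

class HasExampleCol(Params):
    """Hypothetical mixin, not part of this commit."""

    example_col: Param[str] = Param(
        Params._dummy(),
        "example_col",
        "Name of a hypothetical input column.",
        typeConverter=TypeConverters.toString,
    )

    def __init__(self) -> None:
        super().__init__()
        self._setDefault(example_col="example")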

View File

@@ -1,29 +1,30 @@
-# type: ignore
 """Xgboost pyspark integration submodule for helper functions."""
 import inspect
 import json
 import logging
 import sys
 from threading import Thread
+from typing import Any, Callable, Dict, List, Set, Type
 import pyspark
+from pyspark import BarrierTaskContext, SparkContext
 from pyspark.sql.session import SparkSession
 from xgboost.tracker import RabitTracker
 from xgboost import collective
-def get_class_name(cls):
-    """
-    Return the class name.
-    """
+def get_class_name(cls: Type) -> str:
+    """Return the class name."""
     return f"{cls.__module__}.{cls.__name__}"
-def _get_default_params_from_func(func, unsupported_set):
-    """
-    Returns a dictionary of parameters and their default value of function fn.
-    Only the parameters with a default value will be included.
+def _get_default_params_from_func(
+    func: Callable, unsupported_set: Set[str]
+) -> Dict[str, Any]:
+    """Returns a dictionary of parameters and their default value of function fn. Only
+    the parameters with a default value will be included.
     """
     sig = inspect.signature(func)
     filtered_params_dict = {}
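_get_default_params_from_func inspects a callable's signature and keeps only parameters that carry a default and are not in unsupported_set. A hedged sketch of that behaviour (a simplified re-implementation, not the function body from this file):

import inspect
from typing import Any, Callable, Dict, Set

def _sketch_default_params(func: Callable, unsupported_set: Set[str]) -> Dict[str, Any]:
    sig = inspect.signature(func)
    return {
        name: param.default
        for name, param in sig.parameters.items()
        # Keep only parameters that have a default and are not explicitly excluded.
        if param.default is not inspect.Parameter.empty and name not in unsupported_set
    }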
@@ -38,27 +39,26 @@ def _get_default_params_from_func(func, unsupported_set):
 class CommunicatorContext:
-    """
-    A context controlling collective communicator initialization and finalization.
-    This isn't specificially necessary (note Part 3), but it is more understandable coding-wise.
+    """A context controlling collective communicator initialization and finalization.
+    This isn't specificially necessary (note Part 3), but it is more understandable
+    coding-wise.
     """
-    def __init__(self, context, **args):
+    def __init__(self, context: BarrierTaskContext, **args: Any) -> None:
         self.args = args
         self.args["DMLC_TASK_ID"] = str(context.partitionId())
-    def __enter__(self):
+    def __enter__(self) -> None:
         collective.init(**self.args)
-    def __exit__(self, *args):
+    def __exit__(self, *args: Any) -> None:
         collective.finalize()
-def _start_tracker(context, n_workers):
-    """
-    Start Rabit tracker with n_workers
-    """
-    env = {"DMLC_NUM_WORKER": n_workers}
+def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
+    """Start Rabit tracker with n_workers"""
+    env: Dict[str, Any] = {"DMLC_NUM_WORKER": n_workers}
     host = _get_host_ip(context)
     rabit_context = RabitTracker(host_ip=host, n_workers=n_workers)
     env.update(rabit_context.worker_envs())
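CommunicatorContext pairs collective.init and collective.finalize around the training body of each barrier task. A hedged sketch (hypothetical function, not from this commit) of how it is intended to be used with the tracker arguments:

from typing import Any, Dict

from pyspark import BarrierTaskContext
from xgboost import collective
from xgboost.spark.utils import CommunicatorContext

def _hypothetical_train_partition(rabit_args: Dict[str, Any]) -> None:
    context = BarrierTaskContext.get()
    with CommunicatorContext(context, **rabit_args):
        # Collective calls are only valid between init() and finalize(),
        # which the context manager guarantees even if training raises.
        assert collective.get_world_size() >= 1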
@@ -69,27 +69,20 @@ def _start_tracker(context, n_workers):
     return env
-def _get_rabit_args(context, n_workers):
-    """
-    Get rabit context arguments to send to each worker.
-    """
-    # pylint: disable=consider-using-f-string
+def _get_rabit_args(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
+    """Get rabit context arguments to send to each worker."""
     env = _start_tracker(context, n_workers)
     return env
-def _get_host_ip(context):
-    """
-    Gets the hostIP for Spark. This essentially gets the IP of the first worker.
-    """
+def _get_host_ip(context: BarrierTaskContext) -> str:
+    """Gets the hostIP for Spark. This essentially gets the IP of the first worker."""
     task_ip_list = [info.address.split(":")[0] for info in context.getTaskInfos()]
     return task_ip_list[0]
-def _get_args_from_message_list(messages):
-    """
-    A function to send/recieve messages in barrier context mode
-    """
+def _get_args_from_message_list(messages: List[str]) -> Dict[str, Any]:
+    """A function to send/recieve messages in barrier context mode"""
     output = ""
     for message in messages:
         if message != "":
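_get_args_from_message_list is the receiving half of a barrier allGather exchange: only the tracker-starting worker sends a JSON payload, everyone else sends an empty string. A hedged sketch of the full round trip (hypothetical helper, not from this commit):

import json
from typing import Any, Dict

from pyspark import BarrierTaskContext

def _hypothetical_exchange(context: BarrierTaskContext, env: Dict[str, Any]) -> Dict[str, Any]:
    # Partition 0 carries the tracker environment; the rest contribute nothing.
    message = json.dumps(env) if context.partitionId() == 0 else ""
    messages = context.allGather(message)
    # Equivalent to _get_args_from_message_list: keep the single non-empty entry.
    return json.loads(next(m for m in messages if m != ""))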
@@ -98,8 +91,11 @@ def _get_args_from_message_list(messages):
     return json.loads(output)
-def _get_spark_session():
-    """Get or create spark session. Note: This function can only be invoked from driver side."""
+def _get_spark_session() -> SparkSession:
+    """Get or create spark session. Note: This function can only be invoked from driver
+    side.
+    """
     if pyspark.TaskContext.get() is not None:
         # This is a safety check.
         raise RuntimeError(
@@ -108,7 +104,7 @@ def _get_spark_session():
     return SparkSession.builder.getOrCreate()
-def get_logger(name, level="INFO"):
+def get_logger(name: str, level: str = "INFO") -> logging.Logger:
     """Gets a logger by name, or creates and configures it for the first time."""
     logger = logging.getLogger(name)
     logger.setLevel(level)
@@ -119,7 +115,7 @@ def get_logger(name, level="INFO"):
     return logger
-def _get_max_num_concurrent_tasks(spark_context):
+def _get_max_num_concurrent_tasks(spark_context: SparkContext) -> int:
     """Gets the current max number of concurrent tasks."""
     # pylint: disable=protected-access
     # spark 3.1 and above has a different API for fetching max concurrent tasks
@@ -130,13 +126,13 @@ def _get_max_num_concurrent_tasks(spark_context):
     return spark_context._jsc.sc().maxNumConcurrentTasks()
-def _is_local(spark_context) -> bool:
+def _is_local(spark_context: SparkContext) -> bool:
     """Whether it is Spark local mode"""
     # pylint: disable=protected-access
     return spark_context._jsc.sc().isLocal()
-def _get_gpu_id(task_context) -> int:
+def _get_gpu_id(task_context: BarrierTaskContext) -> int:
     """Get the gpu id from the task resources"""
     if task_context is None:
         # This is a safety check.
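With task_context typed as BarrierTaskContext, the GPU lookup can lean on Spark's task resource API. A hedged sketch of how a GPU ordinal can be read from task resources (an illustration, not the body of _get_gpu_id):

from pyspark import BarrierTaskContext

def _hypothetical_gpu_id(task_context: BarrierTaskContext) -> int:
    resources = task_context.resources()
    if "gpu" not in resources:
        raise RuntimeError("The assigned task has no GPU resource.")
    # Spark reports assigned GPU addresses as strings, e.g. ["0"].
    return int(resources["gpu"].addresses[0])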