[pyspark] Add type hint to basic utilities. (#8375)
Parent: 1d2f6de573
Commit: d0b99bdd95
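Every hunk below applies the same pattern: `__init__` methods take `**kwargs: Any` and return `None`, validation hooks return `None`, and the `_xgb_cls` / `_pyspark_model_cls` classmethods are annotated with `typing.Type[...]` because they return class objects rather than instances. A minimal self-contained sketch of that pattern (the class names here are illustrative stand-ins, not the real estimator hierarchy):

from typing import Any, Type


class _Model:
    """Stand-in for a pyspark model class."""


class _Estimator:
    """Stand-in for an estimator whose constructor forwards keyword arguments."""

    def __init__(self, **kwargs: Any) -> None:
        # The real estimators call self.setParams(**kwargs); a plain dict stands in here.
        self.params = dict(kwargs)

    @classmethod
    def _pyspark_model_cls(cls) -> Type[_Model]:
        # The method returns the class object itself, hence Type[_Model] rather than _Model.
        return _Model

    def _validate_params(self) -> None:
        # Validation hooks only raise or return nothing, so they are annotated with None.
        pass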
@@ -1,4 +1,3 @@
-# type: ignore
 """PySpark XGBoost integration interface
 """

@@ -1,11 +1,12 @@
-# type: ignore
 """Xgboost pyspark integration submodule for estimator API."""
 # pylint: disable=too-many-ancestors
+from typing import Any, Type
+
 from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol

 from xgboost import XGBClassifier, XGBRanker, XGBRegressor

-from .core import (
+from .core import (  # type: ignore
     SparkXGBClassifierModel,
     SparkXGBRankerModel,
     SparkXGBRegressorModel,
@@ -95,19 +96,19 @@ class SparkXGBRegressor(_SparkXGBEstimator):

     """

-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any) -> None:
         super().__init__()
         self.setParams(**kwargs)

     @classmethod
-    def _xgb_cls(cls):
+    def _xgb_cls(cls) -> Type[XGBRegressor]:
         return XGBRegressor

     @classmethod
-    def _pyspark_model_cls(cls):
+    def _pyspark_model_cls(cls) -> Type[SparkXGBRegressorModel]:
         return SparkXGBRegressorModel

-    def _validate_params(self):
+    def _validate_params(self) -> None:
         super()._validate_params()
         if self.isDefined(self.qid_col):
             raise ValueError(
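The new annotations do not change behavior; keyword arguments are still forwarded to `setParams`. A short usage sketch, assuming a local Spark installation (the toy DataFrame and parameter values here are illustrative):

from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

from xgboost.spark import SparkXGBRegressor

spark = SparkSession.builder.master("local[2]").getOrCreate()
df_train = spark.createDataFrame(
    [(Vectors.dense(1.0, 2.0), 1.0), (Vectors.dense(3.0, 4.0), 0.0)],
    ["features", "label"],
)

# Arbitrary XGBoost parameters pass through **kwargs exactly as before the annotation.
regressor = SparkXGBRegressor(num_workers=1, max_depth=3)
model = regressor.fit(df_train)
model.transform(df_train).show()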
@@ -209,7 +210,7 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction

     """

-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any) -> None:
         super().__init__()
         # The default 'objective' param value comes from sklearn `XGBClassifier` ctor,
         # but in pyspark we will automatically set objective param depending on
@@ -219,14 +220,14 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
         self.setParams(**kwargs)

     @classmethod
-    def _xgb_cls(cls):
+    def _xgb_cls(cls) -> Type[XGBClassifier]:
         return XGBClassifier

     @classmethod
-    def _pyspark_model_cls(cls):
+    def _pyspark_model_cls(cls) -> Type[SparkXGBClassifierModel]:
         return SparkXGBClassifierModel

-    def _validate_params(self):
+    def _validate_params(self) -> None:
         super()._validate_params()
         if self.isDefined(self.qid_col):
             raise ValueError(
@@ -342,19 +343,19 @@ class SparkXGBRanker(_SparkXGBEstimator):
     >>> model.transform(df_test).show()
     """

-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any) -> None:
         super().__init__()
         self.setParams(**kwargs)

     @classmethod
-    def _xgb_cls(cls):
+    def _xgb_cls(cls) -> Type[XGBRanker]:
         return XGBRanker

     @classmethod
-    def _pyspark_model_cls(cls):
+    def _pyspark_model_cls(cls) -> Type[SparkXGBRankerModel]:
         return SparkXGBRankerModel

-    def _validate_params(self):
+    def _validate_params(self) -> None:
         super()._validate_params()
         if not self.isDefined(self.qid_col):
             raise ValueError(
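As in the regressor, `_validate_params` keeps its runtime check; for the ranker the check is inverted, so `qid_col` must be supplied or a `ValueError` is raised. A hedged example of the expected call (the column names are placeholders):

from xgboost.spark import SparkXGBRanker

# qid_col names the query-group column; the check above rejects a ranker
# constructed without it.
ranker = SparkXGBRanker(qid_col="qid", label_col="label", features_col="features")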
@@ -1,4 +1,3 @@
-# type: ignore
 """Xgboost pyspark integration submodule for params."""
 # pylint: disable=too-few-public-methods
 from pyspark.ml.param import TypeConverters
@@ -12,7 +11,7 @@ class HasArbitraryParamsDict(Params):
     input.
     """

-    arbitrary_params_dict = Param(
+    arbitrary_params_dict: Param[dict] = Param(
         Params._dummy(),
         "arbitrary_params_dict",
         "arbitrary_params_dict This parameter holds all of the additional parameters which are "
@@ -31,6 +30,7 @@ class HasBaseMarginCol(Params):
         Params._dummy(),
         "base_margin_col",
         "This stores the name for the column of the base margin",
+        typeConverter=TypeConverters.toString,
     )


@@ -47,7 +47,7 @@ class HasFeaturesCols(Params):
         typeConverter=TypeConverters.toListString,
     )

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self._setDefault(features_cols=[])

@@ -69,7 +69,7 @@ class HasEnableSparseDataOptim(Params):
         typeConverter=TypeConverters.toBoolean,
     )

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self._setDefault(enable_sparse_data_optim=False)

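The `Param[dict]` annotation above uses the generic type that pyspark's `Param` class exposes. A mixin following the same shape as the parameter classes in this file might look like the sketch below; the class name, parameter name, and default are hypothetical:

from pyspark.ml.param import Param, Params, TypeConverters


class HasExampleCol(Params):
    """Hypothetical mixin declaring a typed string-column parameter."""

    example_col: Param[str] = Param(
        Params._dummy(),
        "example_col",
        "Name of an illustrative input column",
        typeConverter=TypeConverters.toString,
    )

    def __init__(self) -> None:
        super().__init__()
        self._setDefault(example_col="example")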
@@ -1,29 +1,30 @@
-# type: ignore
 """Xgboost pyspark integration submodule for helper functions."""
 import inspect
 import json
 import logging
 import sys
 from threading import Thread
+from typing import Any, Callable, Dict, List, Set, Type

 import pyspark
+from pyspark import BarrierTaskContext, SparkContext
 from pyspark.sql.session import SparkSession
 from xgboost.tracker import RabitTracker

 from xgboost import collective


-def get_class_name(cls):
-    """
-    Return the class name.
-    """
+def get_class_name(cls: Type) -> str:
+    """Return the class name."""
     return f"{cls.__module__}.{cls.__name__}"


-def _get_default_params_from_func(func, unsupported_set):
-    """
-    Returns a dictionary of parameters and their default value of function fn.
-    Only the parameters with a default value will be included.
+def _get_default_params_from_func(
+    func: Callable, unsupported_set: Set[str]
+) -> Dict[str, Any]:
+    """Returns a dictionary of parameters and their default value of function fn. Only
+    the parameters with a default value will be included.
+
     """
     sig = inspect.signature(func)
     filtered_params_dict = {}
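The body of `_get_default_params_from_func` is not shown in this hunk, but its docstring describes a standard `inspect.signature` walk. A self-contained sketch of that technique, not necessarily the exact implementation:

import inspect
from typing import Any, Callable, Dict, Set


def default_params(func: Callable, unsupported: Set[str]) -> Dict[str, Any]:
    """Collect parameters of func that have defaults and are not in unsupported."""
    out: Dict[str, Any] = {}
    for name, param in inspect.signature(func).parameters.items():
        if param.default is not param.empty and name not in unsupported:
            out[name] = param.default
    return out


# Example: defaults of a simple function, skipping "verbose".
def train(max_depth: int = 6, eta: float = 0.3, verbose: bool = True) -> None:
    pass


assert default_params(train, {"verbose"}) == {"max_depth": 6, "eta": 0.3}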
@@ -38,27 +39,26 @@ def _get_default_params_from_func(func, unsupported_set):


 class CommunicatorContext:
-    """
-    A context controlling collective communicator initialization and finalization.
-    This isn't specificially necessary (note Part 3), but it is more understandable coding-wise.
+    """A context controlling collective communicator initialization and finalization.
+    This isn't specificially necessary (note Part 3), but it is more understandable
+    coding-wise.
+
     """

-    def __init__(self, context, **args):
+    def __init__(self, context: BarrierTaskContext, **args: Any) -> None:
         self.args = args
         self.args["DMLC_TASK_ID"] = str(context.partitionId())

-    def __enter__(self):
+    def __enter__(self) -> None:
         collective.init(**self.args)

-    def __exit__(self, *args):
+    def __exit__(self, *args: Any) -> None:
         collective.finalize()


-def _start_tracker(context, n_workers):
-    """
-    Start Rabit tracker with n_workers
-    """
-    env = {"DMLC_NUM_WORKER": n_workers}
+def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
+    """Start Rabit tracker with n_workers"""
+    env: Dict[str, Any] = {"DMLC_NUM_WORKER": n_workers}
     host = _get_host_ip(context)
     rabit_context = RabitTracker(host_ip=host, n_workers=n_workers)
     env.update(rabit_context.worker_envs())
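`CommunicatorContext` pairs `collective.init` with `collective.finalize` around training on each barrier task. A runnable stand-in showing the same enter/exit pairing; the class below is a hypothetical mock, and the print calls substitute for the real collective API:

from typing import Any


class _FakeCommContext:
    """Hypothetical stand-in mirroring CommunicatorContext's init/finalize pairing."""

    def __init__(self, task_id: str, **args: Any) -> None:
        self.args = dict(args)
        self.args["DMLC_TASK_ID"] = task_id

    def __enter__(self) -> None:
        print("collective.init would run with", self.args)

    def __exit__(self, *exc: Any) -> None:
        print("collective.finalize would run")


with _FakeCommContext("0", DMLC_NUM_WORKER=2):
    pass  # distributed training would run here; finalize runs even if it raises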
@@ -69,27 +69,20 @@ def _start_tracker(context, n_workers):
     return env


-def _get_rabit_args(context, n_workers):
-    """
-    Get rabit context arguments to send to each worker.
-    """
-    # pylint: disable=consider-using-f-string
+def _get_rabit_args(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
+    """Get rabit context arguments to send to each worker."""
     env = _start_tracker(context, n_workers)
     return env


-def _get_host_ip(context):
-    """
-    Gets the hostIP for Spark. This essentially gets the IP of the first worker.
-    """
+def _get_host_ip(context: BarrierTaskContext) -> str:
+    """Gets the hostIP for Spark. This essentially gets the IP of the first worker."""
     task_ip_list = [info.address.split(":")[0] for info in context.getTaskInfos()]
     return task_ip_list[0]


-def _get_args_from_message_list(messages):
-    """
-    A function to send/recieve messages in barrier context mode
-    """
+def _get_args_from_message_list(messages: List[str]) -> Dict[str, Any]:
+    """A function to send/recieve messages in barrier context mode"""
     output = ""
     for message in messages:
         if message != "":
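`_get_args_from_message_list` expects exactly one non-empty JSON payload among the messages exchanged in barrier mode (the tracker arguments broadcast by one task, empty strings from the rest). A small self-contained sketch of that parsing, with made-up message contents:

import json
from typing import Any, Dict, List


def args_from_messages(messages: List[str]) -> Dict[str, Any]:
    """Pick the single non-empty message and decode it as JSON."""
    output = ""
    for message in messages:
        if message != "":
            output = message
            break
    return json.loads(output)


# One task sends the tracker environment; every other task sends "".
messages = ["", '{"DMLC_NUM_WORKER": 2, "DMLC_TRACKER_PORT": 9091}', ""]
assert args_from_messages(messages)["DMLC_NUM_WORKER"] == 2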
@@ -98,8 +91,11 @@ def _get_args_from_message_list(messages):
     return json.loads(output)


-def _get_spark_session():
-    """Get or create spark session. Note: This function can only be invoked from driver side."""
+def _get_spark_session() -> SparkSession:
+    """Get or create spark session. Note: This function can only be invoked from driver
+    side.
+
+    """
     if pyspark.TaskContext.get() is not None:
         # This is a safety check.
         raise RuntimeError(
@@ -108,7 +104,7 @@ def _get_spark_session():
     return SparkSession.builder.getOrCreate()


-def get_logger(name, level="INFO"):
+def get_logger(name: str, level: str = "INFO") -> logging.Logger:
     """Gets a logger by name, or creates and configures it for the first time."""
     logger = logging.getLogger(name)
     logger.setLevel(level)
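With the annotation, `get_logger` now advertises a plain `logging.Logger` return type; usage is unchanged. A small example, assuming the helper module is importable as `xgboost.spark.utils` (the file path is not shown in this diff):

import logging

# Assumed import path for the helper module in this diff.
from xgboost.spark.utils import get_logger

logger: logging.Logger = get_logger("xgboost-pyspark-example", level="INFO")
logger.info("logger configured")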
@@ -119,7 +115,7 @@ def get_logger(name, level="INFO"):
     return logger


-def _get_max_num_concurrent_tasks(spark_context):
+def _get_max_num_concurrent_tasks(spark_context: SparkContext) -> int:
     """Gets the current max number of concurrent tasks."""
     # pylint: disable=protected-access
     # spark 3.1 and above has a different API for fetching max concurrent tasks
@@ -130,13 +126,13 @@ def _get_max_num_concurrent_tasks(spark_context):
     return spark_context._jsc.sc().maxNumConcurrentTasks()


-def _is_local(spark_context) -> bool:
+def _is_local(spark_context: SparkContext) -> bool:
     """Whether it is Spark local mode"""
     # pylint: disable=protected-access
     return spark_context._jsc.sc().isLocal()


-def _get_gpu_id(task_context) -> int:
+def _get_gpu_id(task_context: BarrierTaskContext) -> int:
     """Get the gpu id from the task resources"""
     if task_context is None:
         # This is a safety check.