[pyspark] Re-work _fit function (#8630)

Author: Bobby Wang, 2023-01-04 18:21:57 +08:00, committed by GitHub
parent beefd28471
commit d3ad0524e7


@@ -3,7 +3,8 @@
 # pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
 # pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
 import json
-from typing import Iterator, Optional, Tuple
+from collections import namedtuple
+from typing import Iterator, List, Optional, Tuple
 
 import numpy as np
 import pandas as pd
@@ -21,6 +22,7 @@ from pyspark.ml.param.shared import (
     HasWeightCol,
 )
 from pyspark.ml.util import MLReadable, MLWritable
+from pyspark.sql import DataFrame
 from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct
 from pyspark.sql.types import (
     ArrayType,
@@ -471,6 +473,12 @@ def _get_unwrapped_vec_cols(feature_col):
     ]
 
 
+FeatureProp = namedtuple(
+    "FeatureProp",
+    ("enable_sparse_data_optim", "has_validation_col", "features_cols_names"),
+)
+
+
 class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
     def __init__(self):
         super().__init__()
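The new FeatureProp namedtuple bundles the three values that _fit previously threaded through as loose locals, so the training closure can capture one object instead of three. A minimal sketch of how such a carrier behaves (the values below are illustrative, not from the diff):

from collections import namedtuple

FeatureProp = namedtuple(
    "FeatureProp",
    ("enable_sparse_data_optim", "has_validation_col", "features_cols_names"),
)

# Illustrative values only, to show access by field name.
prop = FeatureProp(
    enable_sparse_data_optim=False,
    has_validation_col=True,
    features_cols_names=["f0", "f1"],
)
assert prop.has_validation_col                 # attribute-style reads
assert prop == (False, True, ["f0", "f1"])     # still a plain tuple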
@@ -641,9 +649,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         }
         return booster_params, kwargs_params
 
-    def _fit(self, dataset):
-        # pylint: disable=too-many-statements, too-many-locals
-        self._validate_params()
+    def _prepare_input_columns_and_feature_prop(
+        self, dataset: DataFrame
+    ) -> Tuple[List[str], FeatureProp]:
         label_col = col(self.getOrDefault(self.labelCol)).alias(alias.label)
         select_cols = [label_col]
 
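The extracted helper starts by re-aliasing each user-facing column to a fixed internal name (alias.label and friends) and collecting the results into select_cols for a later prune. A self-contained illustration of that aliasing pattern, with invented column names:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1.0, 0.5), (0.0, 1.5)], ["my_label", "my_feature"])

# Re-alias user columns to fixed internal names, then prune to just those.
select_cols = [
    col("my_label").alias("label"),     # alias.label in the source
    col("my_feature").alias("values"),  # stand-in for a feature alias
]
pruned = df.select(*select_cols)
print(pruned.columns)  # ['label', 'values']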
@@ -698,6 +706,18 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col):
             select_cols.append(col(self.getOrDefault(self.qid_col)).alias(alias.qid))
 
+        feature_prop = FeatureProp(
+            enable_sparse_data_optim, has_validation_col, features_cols_names
+        )
+        return select_cols, feature_prop
+
+    def _prepare_input(self, dataset: DataFrame) -> Tuple[DataFrame, FeatureProp]:
+        """Prepare the input including column pruning, repartition and so on"""
+
+        select_cols, feature_prop = self._prepare_input_columns_and_feature_prop(
+            dataset
+        )
+
         dataset = dataset.select(*select_cols)
 
         num_workers = self.getOrDefault(self.num_workers)
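After pruning, _prepare_input shapes the DataFrame so its partition count matches num_workers, since each barrier task trains on exactly one partition. A minimal sketch of that step, assuming two workers:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.range(100)

num_workers = 2  # illustrative; the estimator reads this from its params
# One partition per XGBoost worker; each barrier task consumes one partition.
df = df.repartition(num_workers)
assert df.rdd.getNumPartitions() == num_workers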
@@ -732,11 +752,13 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             # XGBoost requires qid to be sorted for each partition
             dataset = dataset.sortWithinPartitions(alias.qid, ascending=True)
 
+        return dataset, feature_prop
+
+    def _get_xgb_parameters(self, dataset: DataFrame):
         train_params = self._get_distributed_train_params(dataset)
         booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
             train_params
         )
-
         cpu_per_task = int(
             _get_spark_session().sparkContext.getConf().get("spark.task.cpus", "1")
         )
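Reading spark.task.cpus keeps the booster's nthread within the CPU budget Spark actually grants each task; the "1" fallback mirrors Spark's own default for that setting. A short sketch of the lookup:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()

# How many CPU cores Spark assigns to each task; defaults to 1.
cpu_per_task = int(spark.sparkContext.getConf().get("spark.task.cpus", "1"))

booster_params = {"tree_method": "hist"}
booster_params["nthread"] = cpu_per_task  # keep XGBoost within its CPU budget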
@@ -749,9 +771,6 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             "missing": float(self.getOrDefault(self.missing)),
         }
         booster_params["nthread"] = cpu_per_task
-        use_gpu = self.getOrDefault(self.use_gpu)
-        is_local = _is_local(_get_spark_session().sparkContext)
-
 
         # Remove the parameters whose value is None
         booster_params = {k: v for k, v in booster_params.items() if v is not None}
@@ -760,7 +779,25 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         }
         dmatrix_kwargs = {k: v for k, v in dmatrix_kwargs.items() if v is not None}
 
-        use_hist = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
+        return booster_params, train_call_kwargs_params, dmatrix_kwargs
+
+    def _fit(self, dataset):
+        # pylint: disable=too-many-statements, too-many-locals
+        self._validate_params()
+
+        dataset, feature_prop = self._prepare_input(dataset)
+
+        (
+            booster_params,
+            train_call_kwargs_params,
+            dmatrix_kwargs,
+        ) = self._get_xgb_parameters(dataset)
+
+        use_gpu = self.getOrDefault(self.use_gpu)
+        is_local = _is_local(_get_spark_session().sparkContext)
+        num_workers = self.getOrDefault(self.num_workers)
 
         def _train_booster(pandas_df_iter):
             """Takes in an RDD partition and outputs a booster for that partition after
@@ -772,6 +809,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             context = BarrierTaskContext.get()
 
             gpu_id = None
+            use_hist = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
+
             if use_gpu:
                 gpu_id = context.partitionId() if is_local else _get_gpu_id(context)
                 booster_params["gpu_id"] = gpu_id
@@ -814,12 +853,12 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             with CommunicatorContext(context, **_rabit_args):
                 dtrain, dvalid = create_dmatrix_from_partitions(
                     pandas_df_iter,
-                    features_cols_names,
+                    feature_prop.features_cols_names,
                     gpu_id,
                     use_qdm,
                     dmatrix_kwargs,
-                    enable_sparse_data_optim=enable_sparse_data_optim,
-                    has_validation_col=has_validation_col,
+                    enable_sparse_data_optim=feature_prop.enable_sparse_data_optim,
+                    has_validation_col=feature_prop.has_validation_col,
                 )
                 if dvalid is not None:
                     dval = [(dtrain, "training"), (dvalid, "validation")]
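The (DMatrix, name) pairs collected into dval feed the evals argument of xgboost.train, which is what surfaces per-iteration metrics under the "training" and "validation" labels. A minimal single-process sketch of the same pattern with synthetic data:

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X, y = rng.normal(size=(100, 4)), rng.integers(0, 2, size=100)
dtrain = xgb.DMatrix(X[:80], label=y[:80])
dvalid = xgb.DMatrix(X[80:], label=y[80:])

# Same shape as dval in the diff: metrics are reported per named DMatrix.
evals = [(dtrain, "training"), (dvalid, "validation")]
booster = xgb.train(
    {"objective": "binary:logistic", "tree_method": "hist"},
    dtrain,
    num_boost_round=5,
    evals=evals,
)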