[pyspark] Re-work _fit function (#8630)

Bobby Wang 2023-01-04 18:21:57 +08:00 committed by GitHub
parent beefd28471
commit d3ad0524e7


@@ -3,7 +3,8 @@
 # pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
 # pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
 import json
-from typing import Iterator, Optional, Tuple
+from collections import namedtuple
+from typing import Iterator, List, Optional, Tuple
 
 import numpy as np
 import pandas as pd
@@ -21,6 +22,7 @@ from pyspark.ml.param.shared import (
     HasWeightCol,
 )
 from pyspark.ml.util import MLReadable, MLWritable
+from pyspark.sql import DataFrame
 from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct
 from pyspark.sql.types import (
     ArrayType,
@@ -471,6 +473,12 @@ def _get_unwrapped_vec_cols(feature_col):
     ]
 
 
+FeatureProp = namedtuple(
+    "FeatureProp",
+    ("enable_sparse_data_optim", "has_validation_col", "features_cols_names"),
+)
+
+
 class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
     def __init__(self):
         super().__init__()
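The `FeatureProp` namedtuple introduced above packages the three feature-handling flags that `_fit` previously threaded around as loose local variables. A quick sketch of how it behaves (the field names come from the diff; the values are made up for illustration):

# Hypothetical values, for illustration only.
prop = FeatureProp(
    enable_sparse_data_optim=False,
    has_validation_col=True,
    features_cols_names=["f0", "f1", "f2"],
)

# Downstream code reads the flags by attribute instead of by position:
assert prop.has_validation_col
assert prop.features_cols_names == ["f0", "f1", "f2"]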
@@ -641,9 +649,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         }
         return booster_params, kwargs_params
 
-    def _fit(self, dataset):
-        # pylint: disable=too-many-statements, too-many-locals
-        self._validate_params()
+    def _prepare_input_columns_and_feature_prop(
+        self, dataset: DataFrame
+    ) -> Tuple[List[str], FeatureProp]:
         label_col = col(self.getOrDefault(self.labelCol)).alias(alias.label)
         select_cols = [label_col]
 
@@ -698,6 +706,18 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col):
             select_cols.append(col(self.getOrDefault(self.qid_col)).alias(alias.qid))
 
+        feature_prop = FeatureProp(
+            enable_sparse_data_optim, has_validation_col, features_cols_names
+        )
+
+        return select_cols, feature_prop
+
+    def _prepare_input(self, dataset: DataFrame) -> Tuple[DataFrame, FeatureProp]:
+        """Prepare the input including column pruning, repartition and so on"""
+        select_cols, feature_prop = self._prepare_input_columns_and_feature_prop(
+            dataset
+        )
+
         dataset = dataset.select(*select_cols)
 
         num_workers = self.getOrDefault(self.num_workers)
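`_prepare_input` now owns input preparation end to end: the select above prunes the DataFrame down to the aliased columns it collected, and the repartition/sort logic below spreads the result across `num_workers` partitions. A standalone sketch of the same pruning pattern (the column names here are assumptions, not the estimator's real aliases):

from pyspark.sql.functions import col

# Keep only what training needs, under stable internal names.
select_cols = [
    col("user_label_col").alias("label"),
    col("user_weight_col").alias("weight"),
]
pruned_df = raw_df.select(*select_cols)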
@@ -732,11 +752,13 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             # XGBoost requires qid to be sorted for each partition
             dataset = dataset.sortWithinPartitions(alias.qid, ascending=True)
 
+        return dataset, feature_prop
+
+    def _get_xgb_parameters(self, dataset: DataFrame):
         train_params = self._get_distributed_train_params(dataset)
         booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
             train_params
         )
-
         cpu_per_task = int(
             _get_spark_session().sparkContext.getConf().get("spark.task.cpus", "1")
         )
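`_get_xgb_parameters` sizes `nthread` from `spark.task.cpus`, so each training task uses exactly the CPU slots Spark granted it. The same lookup in isolation (a sketch; assumes an active local session):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[4]").getOrCreate()
# "spark.task.cpus" defaults to 1 when the conf key is unset.
cpu_per_task = int(spark.sparkContext.getConf().get("spark.task.cpus", "1"))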
@@ -749,9 +771,6 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             "missing": float(self.getOrDefault(self.missing)),
         }
         booster_params["nthread"] = cpu_per_task
-        use_gpu = self.getOrDefault(self.use_gpu)
-        is_local = _is_local(_get_spark_session().sparkContext)
-
         # Remove the parameters whose value is None
         booster_params = {k: v for k, v in booster_params.items() if v is not None}
@@ -760,7 +779,25 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         }
         dmatrix_kwargs = {k: v for k, v in dmatrix_kwargs.items() if v is not None}
 
-        use_hist = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
+        return booster_params, train_call_kwargs_params, dmatrix_kwargs
+
+    def _fit(self, dataset):
+        # pylint: disable=too-many-statements, too-many-locals
+        self._validate_params()
+
+        dataset, feature_prop = self._prepare_input(dataset)
+
+        (
+            booster_params,
+            train_call_kwargs_params,
+            dmatrix_kwargs,
+        ) = self._get_xgb_parameters(dataset)
+
+        use_gpu = self.getOrDefault(self.use_gpu)
+
+        is_local = _is_local(_get_spark_session().sparkContext)
+
+        num_workers = self.getOrDefault(self.num_workers)
 
         def _train_booster(pandas_df_iter):
             """Takes in an RDD partition and outputs a booster for that partition after
@@ -772,6 +809,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             context = BarrierTaskContext.get()
 
             gpu_id = None
+            use_hist = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
+
             if use_gpu:
                 gpu_id = context.partitionId() if is_local else _get_gpu_id(context)
                 booster_params["gpu_id"] = gpu_id
@@ -814,12 +853,12 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             with CommunicatorContext(context, **_rabit_args):
                 dtrain, dvalid = create_dmatrix_from_partitions(
                     pandas_df_iter,
-                    features_cols_names,
+                    feature_prop.features_cols_names,
                     gpu_id,
                     use_qdm,
                     dmatrix_kwargs,
-                    enable_sparse_data_optim=enable_sparse_data_optim,
-                    has_validation_col=has_validation_col,
+                    enable_sparse_data_optim=feature_prop.enable_sparse_data_optim,
+                    has_validation_col=feature_prop.has_validation_col,
                 )
                 if dvalid is not None:
                     dval = [(dtrain, "training"), (dvalid, "validation")]
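The `dval` list built here is XGBoost's standard watchlist: `(DMatrix, name)` pairs whose names label the per-round evaluation output. A self-contained example of the same pattern outside Spark (synthetic data, not the estimator's actual call):

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
dtrain = xgb.DMatrix(rng.random((100, 4)), label=rng.integers(0, 2, 100))
dvalid = xgb.DMatrix(rng.random((20, 4)), label=rng.integers(0, 2, 20))

evals_result = {}
xgb.train(
    {"objective": "binary:logistic"},
    dtrain,
    num_boost_round=5,
    evals=[(dtrain, "training"), (dvalid, "validation")],
    evals_result=evals_result,  # per-round metrics keyed by the names above
)
print(evals_result["validation"]["logloss"])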