[pyspark] Re-work _fit function (#8630)
commit d3ad0524e7
parent beefd28471
@@ -3,7 +3,8 @@
 # pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
 # pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
 import json
-from typing import Iterator, Optional, Tuple
+from collections import namedtuple
+from typing import Iterator, List, Optional, Tuple
 
 import numpy as np
 import pandas as pd
@@ -21,6 +22,7 @@ from pyspark.ml.param.shared import (
     HasWeightCol,
 )
 from pyspark.ml.util import MLReadable, MLWritable
+from pyspark.sql import DataFrame
 from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct
 from pyspark.sql.types import (
     ArrayType,
@@ -471,6 +473,12 @@ def _get_unwrapped_vec_cols(feature_col):
     ]
 
 
+FeatureProp = namedtuple(
+    "FeatureProp",
+    ("enable_sparse_data_optim", "has_validation_col", "features_cols_names"),
+)
+
+
 class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
     def __init__(self):
         super().__init__()
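Note: the FeatureProp namedtuple added above bundles the three per-fit flags that were previously passed around as loose locals. A minimal sketch of how such a bundle is built and read; the field values here are made up for illustration, in core.py they are derived from the estimator params:

    from collections import namedtuple

    FeatureProp = namedtuple(
        "FeatureProp",
        ("enable_sparse_data_optim", "has_validation_col", "features_cols_names"),
    )

    # Hypothetical values, for illustration only.
    prop = FeatureProp(
        enable_sparse_data_optim=False,
        has_validation_col=True,
        features_cols_names=["f0", "f1"],
    )
    assert prop.has_validation_col  # fields are read by name, not by position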
@@ -641,9 +649,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         }
         return booster_params, kwargs_params
 
-    def _fit(self, dataset):
-        # pylint: disable=too-many-statements, too-many-locals
-        self._validate_params()
+    def _prepare_input_columns_and_feature_prop(
+        self, dataset: DataFrame
+    ) -> Tuple[List[str], FeatureProp]:
         label_col = col(self.getOrDefault(self.labelCol)).alias(alias.label)
 
         select_cols = [label_col]
@@ -698,6 +706,18 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col):
             select_cols.append(col(self.getOrDefault(self.qid_col)).alias(alias.qid))
 
+        feature_prop = FeatureProp(
+            enable_sparse_data_optim, has_validation_col, features_cols_names
+        )
+        return select_cols, feature_prop
+
+    def _prepare_input(self, dataset: DataFrame) -> Tuple[DataFrame, FeatureProp]:
+        """Prepare the input including column pruning, repartition and so on"""
+
+        select_cols, feature_prop = self._prepare_input_columns_and_feature_prop(
+            dataset
+        )
+
         dataset = dataset.select(*select_cols)
 
         num_workers = self.getOrDefault(self.num_workers)
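Note: the new _prepare_input owns column pruning. It asks _prepare_input_columns_and_feature_prop for the column list plus the FeatureProp bundle, then selects only those columns. A standalone sketch of that select(*cols) pruning step; the toy DataFrame and column names are assumptions, not part of the diff:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(1.0, 0.5, 7)], ["label", "feature", "extra"])

    select_cols = [col("label"), col("feature")]  # pruning drops "extra"
    df.select(*select_cols).show()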
@@ -732,11 +752,13 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             # XGBoost requires qid to be sorted for each partition
             dataset = dataset.sortWithinPartitions(alias.qid, ascending=True)
 
+        return dataset, feature_prop
+
+    def _get_xgb_parameters(self, dataset: DataFrame):
         train_params = self._get_distributed_train_params(dataset)
         booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
             train_params
         )
-
         cpu_per_task = int(
             _get_spark_session().sparkContext.getConf().get("spark.task.cpus", "1")
         )
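Note: the cpu_per_task lookup above reads spark.task.cpus from the active Spark conf, falling back to "1". The same lookup reproduced on its own:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    cpu_per_task = int(spark.sparkContext.getConf().get("spark.task.cpus", "1"))
    print(cpu_per_task)  # 1 unless spark.task.cpus was configured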
@@ -749,9 +771,6 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             "missing": float(self.getOrDefault(self.missing)),
         }
         booster_params["nthread"] = cpu_per_task
-        use_gpu = self.getOrDefault(self.use_gpu)
-
-        is_local = _is_local(_get_spark_session().sparkContext)
 
         # Remove the parameters whose value is None
         booster_params = {k: v for k, v in booster_params.items() if v is not None}
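Note: use_gpu and is_local are not dropped here; they move into the re-worked _fit in the next hunk. The surrounding None-filter comprehension, shown on a toy dict:

    params = {"max_depth": 6, "eta": None, "nthread": 4}
    params = {k: v for k, v in params.items() if v is not None}
    assert params == {"max_depth": 6, "nthread": 4}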
@@ -760,7 +779,25 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         }
         dmatrix_kwargs = {k: v for k, v in dmatrix_kwargs.items() if v is not None}
 
-        use_hist = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
+        return booster_params, train_call_kwargs_params, dmatrix_kwargs
+
+    def _fit(self, dataset):
+        # pylint: disable=too-many-statements, too-many-locals
+        self._validate_params()
+
+        dataset, feature_prop = self._prepare_input(dataset)
+
+        (
+            booster_params,
+            train_call_kwargs_params,
+            dmatrix_kwargs,
+        ) = self._get_xgb_parameters(dataset)
+
+        use_gpu = self.getOrDefault(self.use_gpu)
+
+        is_local = _is_local(_get_spark_session().sparkContext)
+
+        num_workers = self.getOrDefault(self.num_workers)
 
         def _train_booster(pandas_df_iter):
             """Takes in an RDD partition and outputs a booster for that partition after
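Note: after the re-work, _fit is a thin orchestrator: prepare the input, fetch the three parameter dicts, then hand everything to the per-partition _train_booster. A toy class illustrating the same compose-the-helpers shape; the names mirror the diff, the bodies are placeholders rather than the real logic:

    class Trainer:
        def _prepare_input(self, data):
            return data, ("feature", "prop")  # stands in for (dataset, FeatureProp)

        def _get_xgb_parameters(self, data):
            return {"max_depth": 6}, {"num_boost_round": 10}, {"missing": float("nan")}

        def _fit(self, data):
            data, feature_prop = self._prepare_input(data)
            (
                booster_params,
                train_call_kwargs_params,
                dmatrix_kwargs,
            ) = self._get_xgb_parameters(data)
            return booster_params, train_call_kwargs_params, dmatrix_kwargs

    print(Trainer()._fit([1, 2, 3]))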
@@ -772,6 +809,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             context = BarrierTaskContext.get()
 
             gpu_id = None
+            use_hist = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
+
             if use_gpu:
                 gpu_id = context.partitionId() if is_local else _get_gpu_id(context)
                 booster_params["gpu_id"] = gpu_id
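Note: use_hist is now computed inside _train_booster rather than captured from the enclosing _fit scope; the check itself is a simple membership test:

    booster_params = {"tree_method": "gpu_hist"}
    use_hist = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
    assert use_hist  # True for both "hist" and "gpu_hist"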
@@ -814,12 +853,12 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             with CommunicatorContext(context, **_rabit_args):
                 dtrain, dvalid = create_dmatrix_from_partitions(
                     pandas_df_iter,
-                    features_cols_names,
+                    feature_prop.features_cols_names,
                     gpu_id,
                     use_qdm,
                     dmatrix_kwargs,
-                    enable_sparse_data_optim=enable_sparse_data_optim,
-                    has_validation_col=has_validation_col,
+                    enable_sparse_data_optim=feature_prop.enable_sparse_data_optim,
+                    has_validation_col=feature_prop.has_validation_col,
                 )
                 if dvalid is not None:
                     dval = [(dtrain, "training"), (dvalid, "validation")]
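Note: inside the task closure the three flags now travel as one FeatureProp object, so the call sites switch from loose names to attribute access (feature_prop.features_cols_names and friends). Illustrative only:

    from collections import namedtuple

    FeatureProp = namedtuple(
        "FeatureProp",
        ("enable_sparse_data_optim", "has_validation_col", "features_cols_names"),
    )
    fp = FeatureProp(False, True, ["f0", "f1"])

    def task():
        # the closure captures a single object instead of three variables
        return fp.enable_sparse_data_optim, fp.has_validation_col, fp.features_cols_names

    print(task())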