Merge branch 'master' into sync-condition-2023Oct11

This commit is contained in:
Hui Liu 2023-10-27 10:09:37 -07:00
commit 4302200a33
6 changed files with 283 additions and 52 deletions

View File

@ -22,6 +22,9 @@ struct CUDAContext;
struct DeviceSym { struct DeviceSym {
static auto constexpr CPU() { return "cpu"; } static auto constexpr CPU() { return "cpu"; }
static auto constexpr CUDA() { return "cuda"; } static auto constexpr CUDA() { return "cuda"; }
static auto constexpr SyclDefault() { return "sycl"; }
static auto constexpr SyclCPU() { return "sycl:cpu"; }
static auto constexpr SyclGPU() { return "sycl:gpu"; }
}; };
/** /**
@ -33,12 +36,19 @@ struct DeviceOrd {
static bst_d_ordinal_t constexpr CPUOrdinal() { return -1; } static bst_d_ordinal_t constexpr CPUOrdinal() { return -1; }
static bst_d_ordinal_t constexpr InvalidOrdinal() { return -2; } static bst_d_ordinal_t constexpr InvalidOrdinal() { return -2; }
enum Type : std::int16_t { kCPU = 0, kCUDA = 1 } device{kCPU}; enum Type : std::int16_t { kCPU = 0, kCUDA = 1,
// CUDA device ordinal. kSyclDefault = 2, kSyclCPU = 3, kSyclGPU = 4} device{kCPU};
// CUDA or Sycl device ordinal.
bst_d_ordinal_t ordinal{CPUOrdinal()}; bst_d_ordinal_t ordinal{CPUOrdinal()};
[[nodiscard]] bool IsCUDA() const { return device == kCUDA; } [[nodiscard]] bool IsCUDA() const { return device == kCUDA; }
[[nodiscard]] bool IsCPU() const { return device == kCPU; } [[nodiscard]] bool IsCPU() const { return device == kCPU; }
[[nodiscard]] bool IsSyclDefault() const { return device == kSyclDefault; }
[[nodiscard]] bool IsSyclCPU() const { return device == kSyclCPU; }
[[nodiscard]] bool IsSyclGPU() const { return device == kSyclGPU; }
[[nodiscard]] bool IsSycl() const { return (IsSyclDefault() ||
IsSyclCPU() ||
IsSyclGPU()); }
constexpr DeviceOrd() = default; constexpr DeviceOrd() = default;
constexpr DeviceOrd(Type type, bst_d_ordinal_t ord) : device{type}, ordinal{ord} {} constexpr DeviceOrd(Type type, bst_d_ordinal_t ord) : device{type}, ordinal{ord} {}
@ -60,6 +70,31 @@ struct DeviceOrd {
[[nodiscard]] static constexpr auto CUDA(bst_d_ordinal_t ordinal) { [[nodiscard]] static constexpr auto CUDA(bst_d_ordinal_t ordinal) {
return DeviceOrd{kCUDA, ordinal}; return DeviceOrd{kCUDA, ordinal};
} }
/**
* @brief Constructor for SYCL.
*
* @param ordinal SYCL device ordinal.
*/
[[nodiscard]] constexpr static auto SyclDefault(bst_d_ordinal_t ordinal = -1) {
return DeviceOrd{kSyclDefault, ordinal};
}
/**
* @brief Constructor for SYCL CPU.
*
* @param ordinal SYCL CPU device ordinal.
*/
[[nodiscard]] constexpr static auto SyclCPU(bst_d_ordinal_t ordinal = -1) {
return DeviceOrd{kSyclCPU, ordinal};
}
/**
* @brief Constructor for SYCL GPU.
*
* @param ordinal SYCL GPU device ordinal.
*/
[[nodiscard]] constexpr static auto SyclGPU(bst_d_ordinal_t ordinal = -1) {
return DeviceOrd{kSyclGPU, ordinal};
}
[[nodiscard]] bool operator==(DeviceOrd const& that) const { [[nodiscard]] bool operator==(DeviceOrd const& that) const {
return device == that.device && ordinal == that.ordinal; return device == that.device && ordinal == that.ordinal;
@ -74,6 +109,12 @@ struct DeviceOrd {
return DeviceSym::CPU(); return DeviceSym::CPU();
case DeviceOrd::kCUDA: case DeviceOrd::kCUDA:
return DeviceSym::CUDA() + (':' + std::to_string(ordinal)); return DeviceSym::CUDA() + (':' + std::to_string(ordinal));
case DeviceOrd::kSyclDefault:
return DeviceSym::SyclDefault() + (':' + std::to_string(ordinal));
case DeviceOrd::kSyclCPU:
return DeviceSym::SyclCPU() + (':' + std::to_string(ordinal));
case DeviceOrd::kSyclGPU:
return DeviceSym::SyclGPU() + (':' + std::to_string(ordinal));
default: { default: {
LOG(FATAL) << "Unknown device."; LOG(FATAL) << "Unknown device.";
return ""; return "";
@ -142,6 +183,25 @@ struct Context : public XGBoostParameter<Context> {
* @brief Is XGBoost running on a CUDA device? * @brief Is XGBoost running on a CUDA device?
*/ */
[[nodiscard]] bool IsCUDA() const { return Device().IsCUDA(); } [[nodiscard]] bool IsCUDA() const { return Device().IsCUDA(); }
/**
* @brief Is XGBoost running on the default SYCL device?
*/
[[nodiscard]] bool IsSyclDefault() const { return Device().IsSyclDefault(); }
/**
* @brief Is XGBoost running on a SYCL CPU?
*/
[[nodiscard]] bool IsSyclCPU() const { return Device().IsSyclCPU(); }
/**
* @brief Is XGBoost running on a SYCL GPU?
*/
[[nodiscard]] bool IsSyclGPU() const { return Device().IsSyclGPU(); }
/**
* @brief Is XGBoost running on any SYCL device?
*/
[[nodiscard]] bool IsSycl() const { return IsSyclDefault()
|| IsSyclCPU()
|| IsSyclGPU(); }
/** /**
* @brief Get the current device and ordinal. * @brief Get the current device and ordinal.
*/ */
@ -175,6 +235,7 @@ struct Context : public XGBoostParameter<Context> {
Context ctx = *this; Context ctx = *this;
return ctx.SetDevice(DeviceOrd::CPU()); return ctx.SetDevice(DeviceOrd::CPU());
} }
/** /**
* @brief Call function based on the current device. * @brief Call function based on the current device.
*/ */
@ -196,6 +257,20 @@ struct Context : public XGBoostParameter<Context> {
return std::invoke_result_t<CPUFn>(); return std::invoke_result_t<CPUFn>();
} }
/**
* @brief Call function for sycl devices
*/
template <typename CPUFn, typename CUDAFn, typename SYCLFn>
decltype(auto) DispatchDevice(CPUFn&& cpu_fn, CUDAFn&& cuda_fn, SYCLFn&& sycl_fn) const {
static_assert(std::is_same_v<std::invoke_result_t<CPUFn>, std::invoke_result_t<CUDAFn>>);
static_assert(std::is_same_v<std::invoke_result_t<CPUFn>, std::invoke_result_t<SYCLFn>>);
if (this->Device().IsSycl()) {
return sycl_fn();
} else {
return DispatchDevice(cpu_fn, cuda_fn);
}
}
// declare parameters // declare parameters
DMLC_DECLARE_PARAMETER(Context) { DMLC_DECLARE_PARAMETER(Context) {
DMLC_DECLARE_FIELD(seed) DMLC_DECLARE_FIELD(seed)

View File

@ -347,7 +347,7 @@ class _SparkXGBParams(
def _validate_gpu_params(self) -> None: def _validate_gpu_params(self) -> None:
"""Validate the gpu parameters and gpu configurations""" """Validate the gpu parameters and gpu configurations"""
if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu): if self._run_on_gpu():
ss = _get_spark_session() ss = _get_spark_session()
sc = ss.sparkContext sc = ss.sparkContext
@ -414,9 +414,7 @@ class _SparkXGBParams(
) )
if self.getOrDefault(self.features_cols): if self.getOrDefault(self.features_cols):
if not use_cuda(self.getOrDefault(self.device)) and not self.getOrDefault( if not self._run_on_gpu():
self.use_gpu
):
raise ValueError( raise ValueError(
"features_col param with list value requires `device=cuda`." "features_col param with list value requires `device=cuda`."
) )
@ -473,6 +471,15 @@ class _SparkXGBParams(
self._validate_gpu_params() self._validate_gpu_params()
def _run_on_gpu(self) -> bool:
"""If train or transform on the gpu according to the parameters"""
return (
use_cuda(self.getOrDefault(self.device))
or self.getOrDefault(self.use_gpu)
or self.getOrDefault(self.getParam("tree_method")) == "gpu_hist"
)
def _validate_and_convert_feature_col_as_float_col_list( def _validate_and_convert_feature_col_as_float_col_list(
dataset: DataFrame, features_col_names: List[str] dataset: DataFrame, features_col_names: List[str]
@ -905,7 +912,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
"""Check if stage-level scheduling is not needed, """Check if stage-level scheduling is not needed,
return true to skip stage-level scheduling""" return true to skip stage-level scheduling"""
if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu): if self._run_on_gpu():
ss = _get_spark_session() ss = _get_spark_session()
sc = ss.sparkContext sc = ss.sparkContext
@ -1022,9 +1029,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
dmatrix_kwargs, dmatrix_kwargs,
) = self._get_xgb_parameters(dataset) ) = self._get_xgb_parameters(dataset)
run_on_gpu = use_cuda(self.getOrDefault(self.device)) or self.getOrDefault( run_on_gpu = self._run_on_gpu()
self.use_gpu
)
is_local = _is_local(_get_spark_session().sparkContext) is_local = _is_local(_get_spark_session().sparkContext)
num_workers = self.getOrDefault(self.num_workers) num_workers = self.getOrDefault(self.num_workers)
@ -1318,12 +1324,15 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
dataset = dataset.drop(pred_struct_col) dataset = dataset.drop(pred_struct_col)
return dataset return dataset
def _gpu_transform(self) -> bool: def _run_on_gpu(self) -> bool:
"""If gpu is used to do the prediction, true to gpu prediction""" """If gpu is used to do the prediction according to the parameters
and spark configurations"""
use_gpu_by_params = super()._run_on_gpu()
if _is_local(_get_spark_session().sparkContext): if _is_local(_get_spark_session().sparkContext):
# if it's local model, we just use the internal "device" # if it's local model, no need to check the spark configurations
return use_cuda(self.getOrDefault(self.device)) return use_gpu_by_params
gpu_per_task = ( gpu_per_task = (
_get_spark_session() _get_spark_session()
@ -1333,15 +1342,15 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
# User don't set gpu configurations, just use cpu # User don't set gpu configurations, just use cpu
if gpu_per_task is None: if gpu_per_task is None:
if use_cuda(self.getOrDefault(self.device)): if use_gpu_by_params:
get_logger("XGBoost-PySpark").warning( get_logger("XGBoost-PySpark").warning(
"Do the prediction on the CPUs since " "Do the prediction on the CPUs since "
"no gpu configurations are set" "no gpu configurations are set"
) )
return False return False
# User already sets the gpu configurations, we just use the internal "device". # User already sets the gpu configurations.
return use_cuda(self.getOrDefault(self.device)) return use_gpu_by_params
def _transform(self, dataset: DataFrame) -> DataFrame: def _transform(self, dataset: DataFrame) -> DataFrame:
# pylint: disable=too-many-statements, too-many-locals # pylint: disable=too-many-statements, too-many-locals
@ -1367,7 +1376,7 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
_, schema = self._out_schema() _, schema = self._out_schema()
is_local = _is_local(_get_spark_session().sparkContext) is_local = _is_local(_get_spark_session().sparkContext)
run_on_gpu = self._gpu_transform() run_on_gpu = self._run_on_gpu()
@pandas_udf(schema) # type: ignore @pandas_udf(schema) # type: ignore
def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]: def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
@ -1381,9 +1390,10 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
dev_ordinal = -1 dev_ordinal = -1
if is_cudf_available(): msg = "Do the inference on the CPUs"
if run_on_gpu:
if is_cudf_available() and is_cupy_available():
if is_local: if is_local:
if run_on_gpu and is_cupy_available():
import cupy as cp # pylint: disable=import-error import cupy as cp # pylint: disable=import-error
total_gpus = cp.cuda.runtime.getDeviceCount() total_gpus = cp.cuda.runtime.getDeviceCount()
@ -1392,23 +1402,18 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
# For transform local mode, default the dev_ordinal to # For transform local mode, default the dev_ordinal to
# (partition id) % gpus. # (partition id) % gpus.
dev_ordinal = partition_id % total_gpus dev_ordinal = partition_id % total_gpus
elif run_on_gpu: else:
dev_ordinal = _get_gpu_id(context) dev_ordinal = _get_gpu_id(context)
if dev_ordinal >= 0: if dev_ordinal >= 0:
device = "cuda:" + str(dev_ordinal) device = "cuda:" + str(dev_ordinal)
get_logger("XGBoost-PySpark").info( msg = "Do the inference with device: " + device
"Do the inference with device: %s", device
)
model.set_params(device=device) model.set_params(device=device)
else: else:
get_logger("XGBoost-PySpark").info("Do the inference on the CPUs") msg = "Couldn't get the correct gpu id, fallback the inference on the CPUs"
else: else:
msg = ( msg = "CUDF or Cupy is unavailable, fallback the inference on the CPUs"
"CUDF is unavailable, fallback the inference on the CPUs"
if run_on_gpu
else "Do the inference on the CPUs"
)
get_logger("XGBoost-PySpark").info(msg) get_logger("XGBoost-PySpark").info(msg)
def to_gpu_if_possible(data: ArrayLike) -> ArrayLike: def to_gpu_if_possible(data: ArrayLike) -> ArrayLike:

View File

@ -104,19 +104,44 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) {
// mingw hangs on regex using rtools 430. Basic checks only. // mingw hangs on regex using rtools 430. Basic checks only.
CHECK_GE(input.size(), 3) << msg; CHECK_GE(input.size(), 3) << msg;
auto substr = input.substr(0, 3); auto substr = input.substr(0, 3);
bool valid = substr == "cpu" || substr == "cud" || substr == "gpu"; bool valid = substr == "cpu" || substr == "cud" || substr == "gpu" || substr == "syc";
CHECK(valid) << msg; CHECK(valid) << msg;
#else #else
std::regex pattern{"gpu(:[0-9]+)?|cuda(:[0-9]+)?|cpu"}; std::regex pattern{"gpu(:[0-9]+)?|cuda(:[0-9]+)?|cpu|sycl(:cpu|:gpu)?(:-1|:[0-9]+)?"};
if (!std::regex_match(input, pattern)) { if (!std::regex_match(input, pattern)) {
fatal(); fatal();
} }
#endif // defined(__MINGW32__) #endif // defined(__MINGW32__)
// handle alias // handle alias
std::string s_device = std::regex_replace(input, std::regex{"gpu"}, DeviceSym::CUDA()); #if defined(__MINGW32__)
// mingw hangs on regex using rtools 430. Basic checks only.
bool is_sycl = (substr == "syc");
#else
bool is_sycl = std::regex_match(input, std::regex("sycl(:cpu|:gpu)?(:-1|:[0-9]+)?"));
#endif // defined(__MINGW32__)
std::string s_device = input;
if (!is_sycl) {
s_device = std::regex_replace(s_device, std::regex{"gpu"}, DeviceSym::CUDA());
}
auto split_it = std::find(s_device.cbegin(), s_device.cend(), ':'); auto split_it = std::find(s_device.cbegin(), s_device.cend(), ':');
// For these cases we need to move iterator to the end, not to look for a ordinal.
if ((s_device == "sycl:cpu") ||
(s_device == "sycl:gpu")) {
split_it = s_device.cend();
}
// For s_device like "sycl:gpu:1"
if (split_it != s_device.cend()) {
auto second_split_it = std::find(split_it + 1, s_device.cend(), ':');
if (second_split_it != s_device.cend()) {
split_it = second_split_it;
}
}
DeviceOrd device; DeviceOrd device;
device.ordinal = DeviceOrd::InvalidOrdinal(); // mark it invalid for check. device.ordinal = DeviceOrd::InvalidOrdinal(); // mark it invalid for check.
if (split_it == s_device.cend()) { if (split_it == s_device.cend()) {
@ -125,15 +150,22 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) {
device = DeviceOrd::CPU(); device = DeviceOrd::CPU();
} else if (s_device == DeviceSym::CUDA()) { } else if (s_device == DeviceSym::CUDA()) {
device = DeviceOrd::CUDA(0); // use 0 as default; device = DeviceOrd::CUDA(0); // use 0 as default;
} else if (s_device == DeviceSym::SyclDefault()) {
device = DeviceOrd::SyclDefault();
} else if (s_device == DeviceSym::SyclCPU()) {
device = DeviceOrd::SyclCPU();
} else if (s_device == DeviceSym::SyclGPU()) {
device = DeviceOrd::SyclGPU();
} else { } else {
fatal(); fatal();
} }
} else { } else {
// must be CUDA when ordinal is specifed. // must be CUDA or SYCL when ordinal is specifed.
// +1 for colon // +1 for colon
std::size_t offset = std::distance(s_device.cbegin(), split_it) + 1; std::size_t offset = std::distance(s_device.cbegin(), split_it) + 1;
// substr // substr
StringView s_ordinal = {s_device.data() + offset, s_device.size() - offset}; StringView s_ordinal = {s_device.data() + offset, s_device.size() - offset};
StringView s_type = {s_device.data(), offset - 1};
if (s_ordinal.empty()) { if (s_ordinal.empty()) {
fatal(); fatal();
} }
@ -143,13 +175,23 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) {
} }
CHECK_LE(opt_id.value(), std::numeric_limits<bst_d_ordinal_t>::max()) CHECK_LE(opt_id.value(), std::numeric_limits<bst_d_ordinal_t>::max())
<< "Ordinal value too large."; << "Ordinal value too large.";
if (s_type == DeviceSym::SyclDefault()) {
device = DeviceOrd::SyclDefault(opt_id.value());
} else if (s_type == DeviceSym::SyclCPU()) {
device = DeviceOrd::SyclCPU(opt_id.value());
} else if (s_type == DeviceSym::SyclGPU()) {
device = DeviceOrd::SyclGPU(opt_id.value());
} else {
device = DeviceOrd::CUDA(opt_id.value()); device = DeviceOrd::CUDA(opt_id.value());
} }
}
if (device.ordinal < DeviceOrd::CPUOrdinal()) { if (device.ordinal < DeviceOrd::CPUOrdinal()) {
fatal(); fatal();
} }
if (device.IsCUDA()) {
device = CUDAOrdinal(device, fail_on_invalid_gpu_id); device = CUDAOrdinal(device, fail_on_invalid_gpu_id);
}
return device; return device;
} }
@ -216,7 +258,7 @@ void Context::SetDeviceOrdinal(Args const& kwargs) {
if (this->IsCPU()) { if (this->IsCPU()) {
CHECK_EQ(this->device_.ordinal, DeviceOrd::CPUOrdinal()); CHECK_EQ(this->device_.ordinal, DeviceOrd::CPUOrdinal());
} else { } else if (this->IsCUDA()) {
CHECK_GT(this->device_.ordinal, DeviceOrd::CPUOrdinal()); CHECK_GT(this->device_.ordinal, DeviceOrd::CPUOrdinal());
} }
} }

View File

@ -45,4 +45,97 @@ TEST(Context, ErrorInit) {
ASSERT_NE(msg.find("foo"), std::string::npos); ASSERT_NE(msg.find("foo"), std::string::npos);
} }
} }
TEST(Context, SYCL) {
Context ctx;
// Default SYCL device
{
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
ASSERT_EQ(ctx.Device(), DeviceOrd::SyclDefault());
ASSERT_EQ(ctx.Ordinal(), -1);
std::int32_t flag{0};
ctx.DispatchDevice([&] { flag = -1; }, [&] { flag = 1; }, [&] { flag = 2; });
ASSERT_EQ(flag, 2);
std::stringstream ss;
ss << ctx.Device();
ASSERT_EQ(ss.str(), "sycl:-1");
}
// SYCL device with idx
{
ctx.UpdateAllowUnknown(Args{{"device", "sycl:42"}});
ASSERT_EQ(ctx.Device(), DeviceOrd::SyclDefault(42));
ASSERT_EQ(ctx.Ordinal(), 42);
std::int32_t flag{0};
ctx.DispatchDevice([&] { flag = -1; }, [&] { flag = 1; }, [&] { flag = 2; });
ASSERT_EQ(flag, 2);
std::stringstream ss;
ss << ctx.Device();
ASSERT_EQ(ss.str(), "sycl:42");
}
// SYCL cpu
{
ctx.UpdateAllowUnknown(Args{{"device", "sycl:cpu"}});
ASSERT_EQ(ctx.Device(), DeviceOrd::SyclCPU());
ASSERT_EQ(ctx.Ordinal(), -1);
std::int32_t flag{0};
ctx.DispatchDevice([&] { flag = -1; }, [&] { flag = 1; }, [&] { flag = 2; });
ASSERT_EQ(flag, 2);
std::stringstream ss;
ss << ctx.Device();
ASSERT_EQ(ss.str(), "sycl:cpu:-1");
}
// SYCL cpu with idx
{
ctx.UpdateAllowUnknown(Args{{"device", "sycl:cpu:42"}});
ASSERT_EQ(ctx.Device(), DeviceOrd::SyclCPU(42));
ASSERT_EQ(ctx.Ordinal(), 42);
std::int32_t flag{0};
ctx.DispatchDevice([&] { flag = -1; }, [&] { flag = 1; }, [&] { flag = 2; });
ASSERT_EQ(flag, 2);
std::stringstream ss;
ss << ctx.Device();
ASSERT_EQ(ss.str(), "sycl:cpu:42");
}
// SYCL gpu
{
ctx.UpdateAllowUnknown(Args{{"device", "sycl:gpu"}});
ASSERT_EQ(ctx.Device(), DeviceOrd::SyclGPU());
ASSERT_EQ(ctx.Ordinal(), -1);
std::int32_t flag{0};
ctx.DispatchDevice([&] { flag = -1; }, [&] { flag = 1; }, [&] { flag = 2; });
ASSERT_EQ(flag, 2);
std::stringstream ss;
ss << ctx.Device();
ASSERT_EQ(ss.str(), "sycl:gpu:-1");
}
// SYCL gpu with idx
{
ctx.UpdateAllowUnknown(Args{{"device", "sycl:gpu:42"}});
ASSERT_EQ(ctx.Device(), DeviceOrd::SyclGPU(42));
ASSERT_EQ(ctx.Ordinal(), 42);
std::int32_t flag{0};
ctx.DispatchDevice([&] { flag = -1; }, [&] { flag = 1; }, [&] { flag = 2; });
ASSERT_EQ(flag, 2);
std::stringstream ss;
ss << ctx.Device();
ASSERT_EQ(ss.str(), "sycl:gpu:42");
}
}
} // namespace xgboost } // namespace xgboost

View File

@ -251,10 +251,10 @@ def test_gpu_transform(spark_diabetes_dataset) -> None:
model: SparkXGBRegressorModel = regressor.fit(train_df) model: SparkXGBRegressorModel = regressor.fit(train_df)
# The model trained with GPUs, and transform with GPU configurations. # The model trained with GPUs, and transform with GPU configurations.
assert model._gpu_transform() assert model._run_on_gpu()
model.set_device("cpu") model.set_device("cpu")
assert not model._gpu_transform() assert not model._run_on_gpu()
# without error # without error
cpu_rows = model.transform(test_df).select("prediction").collect() cpu_rows = model.transform(test_df).select("prediction").collect()
@ -263,11 +263,11 @@ def test_gpu_transform(spark_diabetes_dataset) -> None:
# The model trained with CPUs. Even with GPU configurations, # The model trained with CPUs. Even with GPU configurations,
# still prefer transforming with CPUs # still prefer transforming with CPUs
assert not model._gpu_transform() assert not model._run_on_gpu()
# Set gpu transform explicitly. # Set gpu transform explicitly.
model.set_device("cuda") model.set_device("cuda")
assert model._gpu_transform() assert model._run_on_gpu()
# without error # without error
gpu_rows = model.transform(test_df).select("prediction").collect() gpu_rows = model.transform(test_df).select("prediction").collect()

View File

@ -888,6 +888,22 @@ class TestPySparkLocal:
clf = SparkXGBClassifier(device="cuda") clf = SparkXGBClassifier(device="cuda")
clf._validate_params() clf._validate_params()
def test_gpu_params(self) -> None:
clf = SparkXGBClassifier()
assert not clf._run_on_gpu()
clf = SparkXGBClassifier(device="cuda", tree_method="hist")
assert clf._run_on_gpu()
clf = SparkXGBClassifier(device="cuda")
assert clf._run_on_gpu()
clf = SparkXGBClassifier(tree_method="gpu_hist")
assert clf._run_on_gpu()
clf = SparkXGBClassifier(use_gpu=True)
assert clf._run_on_gpu()
def test_gpu_transform(self, clf_data: ClfData) -> None: def test_gpu_transform(self, clf_data: ClfData) -> None:
"""local mode""" """local mode"""
classifier = SparkXGBClassifier(device="cpu") classifier = SparkXGBClassifier(device="cpu")
@ -898,23 +914,23 @@ class TestPySparkLocal:
model.write().overwrite().save(path) model.write().overwrite().save(path)
# The model trained with CPU, transform defaults to cpu # The model trained with CPU, transform defaults to cpu
assert not model._gpu_transform() assert not model._run_on_gpu()
# without error # without error
model.transform(clf_data.cls_df_test).collect() model.transform(clf_data.cls_df_test).collect()
model.set_device("cuda") model.set_device("cuda")
assert model._gpu_transform() assert model._run_on_gpu()
model_loaded = SparkXGBClassifierModel.load(path) model_loaded = SparkXGBClassifierModel.load(path)
# The model trained with CPU, transform defaults to cpu # The model trained with CPU, transform defaults to cpu
assert not model_loaded._gpu_transform() assert not model_loaded._run_on_gpu()
# without error # without error
model_loaded.transform(clf_data.cls_df_test).collect() model_loaded.transform(clf_data.cls_df_test).collect()
model_loaded.set_device("cuda") model_loaded.set_device("cuda")
assert model_loaded._gpu_transform() assert model_loaded._run_on_gpu()
class XgboostLocalTest(SparkTestCase): class XgboostLocalTest(SparkTestCase):