parent
d231e7c35f
commit
357a78b3de
@ -575,6 +575,20 @@ void GPUDartPredictInc(common::Span<float> out_predts,
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
void GPUDartInplacePredictInc(common::Span<float> out_predts,
|
||||||
|
common::Span<float> predts, float tree_w,
|
||||||
|
size_t n_rows, float base_score,
|
||||||
|
bst_group_t n_groups,
|
||||||
|
bst_group_t group)
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
|
; // NOLINT
|
||||||
|
#else
|
||||||
|
{
|
||||||
|
common::AssertGPUSupport();
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
class Dart : public GBTree {
|
class Dart : public GBTree {
|
||||||
public:
|
public:
|
||||||
explicit Dart(LearnerModelParam const* booster_config) :
|
explicit Dart(LearnerModelParam const* booster_config) :
|
||||||
@ -728,13 +742,14 @@ class Dart : public GBTree {
|
|||||||
gpu_predictor_.get()
|
gpu_predictor_.get()
|
||||||
#endif // defined(XGBOOST_USE_CUDA)
|
#endif // defined(XGBOOST_USE_CUDA)
|
||||||
};
|
};
|
||||||
|
Predictor const * predictor {nullptr};
|
||||||
|
|
||||||
MetaInfo info;
|
MetaInfo info;
|
||||||
StringView msg{"Unsupported data type for inplace predict."};
|
StringView msg{"Unsupported data type for inplace predict."};
|
||||||
int32_t device = GenericParameter::kCpuId;
|
int32_t device = GenericParameter::kCpuId;
|
||||||
|
PredictionCacheEntry predts;
|
||||||
// Inplace predict is not used for training, so no need to drop tree.
|
// Inplace predict is not used for training, so no need to drop tree.
|
||||||
for (size_t i = tree_begin; i < tree_end; ++i) {
|
for (size_t i = tree_begin; i < tree_end; ++i) {
|
||||||
PredictionCacheEntry predts;
|
|
||||||
if (tparam_.predictor == PredictorType::kAuto) {
|
if (tparam_.predictor == PredictorType::kAuto) {
|
||||||
// Try both predictor implementations
|
// Try both predictor implementations
|
||||||
bool success = false;
|
bool success = false;
|
||||||
@ -742,6 +757,7 @@ class Dart : public GBTree {
|
|||||||
if (p && p->InplacePredict(x, nullptr, model_, missing, &predts, i,
|
if (p && p->InplacePredict(x, nullptr, model_, missing, &predts, i,
|
||||||
i + 1)) {
|
i + 1)) {
|
||||||
success = true;
|
success = true;
|
||||||
|
predictor = p;
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
device = predts.predictions.DeviceIdx();
|
device = predts.predictions.DeviceIdx();
|
||||||
#endif // defined(XGBOOST_USE_CUDA)
|
#endif // defined(XGBOOST_USE_CUDA)
|
||||||
@ -750,46 +766,53 @@ class Dart : public GBTree {
|
|||||||
}
|
}
|
||||||
CHECK(success) << msg;
|
CHECK(success) << msg;
|
||||||
} else {
|
} else {
|
||||||
// No base margin for each tree
|
// No base margin from meta info for each tree
|
||||||
bool success = this->GetPredictor()->InplacePredict(
|
predictor = this->GetPredictor().get();
|
||||||
x, nullptr, model_, missing, &predts, i, i + 1);
|
bool success = predictor->InplacePredict(x, nullptr, model_, missing,
|
||||||
|
&predts, i, i + 1);
|
||||||
device = predts.predictions.DeviceIdx();
|
device = predts.predictions.DeviceIdx();
|
||||||
CHECK(success) << msg;
|
CHECK(success) << msg;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto w = this->weight_drop_.at(i);
|
auto w = this->weight_drop_.at(i);
|
||||||
auto &h_predts = predts.predictions.HostVector();
|
size_t n_groups = model_.learner_model_param->num_output_group;
|
||||||
auto &h_out_predts = out_preds->predictions.HostVector();
|
auto n_rows = predts.predictions.Size() / n_groups;
|
||||||
|
|
||||||
if (i == tree_begin) {
|
if (i == tree_begin) {
|
||||||
auto n_rows =
|
// base margin is added here.
|
||||||
h_predts.size() / model_.learner_model_param->num_output_group;
|
|
||||||
if (p_m) {
|
if (p_m) {
|
||||||
p_m->Info().num_row_ = n_rows;
|
p_m->Info().num_row_ = n_rows;
|
||||||
cpu_predictor_->InitOutPredictions(p_m->Info(),
|
predictor->InitOutPredictions(p_m->Info(), &out_preds->predictions,
|
||||||
&out_preds->predictions, model_);
|
model_);
|
||||||
} else {
|
} else {
|
||||||
info.num_row_ = n_rows;
|
info.num_row_ = n_rows;
|
||||||
cpu_predictor_->InitOutPredictions(info, &out_preds->predictions,
|
predictor->InitOutPredictions(info, &out_preds->predictions, model_);
|
||||||
model_);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Multiply by the tree weight
|
// Multiply by the tree weight
|
||||||
CHECK_EQ(h_predts.size(), h_out_predts.size());
|
CHECK_EQ(predts.predictions.Size(), out_preds->predictions.Size());
|
||||||
|
auto group = model_.tree_info.at(i);
|
||||||
|
|
||||||
|
if (device == GenericParameter::kCpuId) {
|
||||||
|
auto &h_predts = predts.predictions.HostVector();
|
||||||
|
auto &h_out_predts = out_preds->predictions.HostVector();
|
||||||
#pragma omp parallel for
|
#pragma omp parallel for
|
||||||
for (omp_ulong i = 0; i < h_out_predts.size(); ++i) {
|
for (omp_ulong ridx = 0; ridx < n_rows; ++ridx) {
|
||||||
// Need to remove the base margin from each individual tree.
|
const size_t offset = ridx * n_groups + group;
|
||||||
h_out_predts[i] +=
|
// Need to remove the base margin from each individual tree.
|
||||||
(h_predts[i] - model_.learner_model_param->base_score) * w;
|
h_out_predts[offset] +=
|
||||||
|
(h_predts[offset] - model_.learner_model_param->base_score) * w;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
out_preds->predictions.SetDevice(device);
|
||||||
|
predts.predictions.SetDevice(device);
|
||||||
|
GPUDartInplacePredictInc(out_preds->predictions.DeviceSpan(),
|
||||||
|
predts.predictions.DeviceSpan(), w, n_rows,
|
||||||
|
model_.learner_model_param->base_score,
|
||||||
|
n_groups, group);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (device != GenericParameter::kCpuId) {
|
|
||||||
out_preds->predictions.SetDevice(device);
|
|
||||||
out_preds->predictions.DeviceSpan();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void PredictInstance(const SparsePage::Inst &inst,
|
void PredictInstance(const SparsePage::Inst &inst,
|
||||||
|
|||||||
@ -14,5 +14,15 @@ void GPUDartPredictInc(common::Span<float> out_predts,
|
|||||||
out_predts[offset] += (predts[offset] * tree_w);
|
out_predts[offset] += (predts[offset] * tree_w);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void GPUDartInplacePredictInc(common::Span<float> out_predts,
|
||||||
|
common::Span<float> predts, float tree_w,
|
||||||
|
size_t n_rows, float base_score,
|
||||||
|
bst_group_t n_groups, bst_group_t group) {
|
||||||
|
dh::LaunchN(dh::CurrentDevice(), n_rows, [=] XGBOOST_DEVICE(size_t ridx) {
|
||||||
|
const size_t offset = ridx * n_groups + group;
|
||||||
|
out_predts[offset] += (predts[offset] - base_score) * tree_w;
|
||||||
|
});
|
||||||
|
}
|
||||||
} // namespace gbm
|
} // namespace gbm
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -332,27 +332,44 @@ class TestGPUPredict:
|
|||||||
rmse = mean_squared_error(y_true=y, y_pred=pred, squared=False)
|
rmse = mean_squared_error(y_true=y, y_pred=pred, squared=False)
|
||||||
np.testing.assert_almost_equal(rmse, eval_history['train']['rmse'][-1], decimal=5)
|
np.testing.assert_almost_equal(rmse, eval_history['train']['rmse'][-1], decimal=5)
|
||||||
|
|
||||||
def test_predict_dart(self):
|
@pytest.mark.parametrize("n_classes", [2, 3])
|
||||||
|
def test_predict_dart(self, n_classes):
|
||||||
|
from sklearn.datasets import make_classification
|
||||||
import cupy as cp
|
import cupy as cp
|
||||||
rng = cp.random.RandomState(1994)
|
|
||||||
n_samples = 1000
|
n_samples = 1000
|
||||||
X = rng.randn(n_samples, 10)
|
X_, y_ = make_classification(
|
||||||
y = rng.randn(n_samples)
|
n_samples=n_samples, n_informative=5, n_classes=n_classes
|
||||||
|
)
|
||||||
|
X, y = cp.array(X_), cp.array(y_)
|
||||||
|
|
||||||
Xy = xgb.DMatrix(X, y)
|
Xy = xgb.DMatrix(X, y)
|
||||||
booster = xgb.train(
|
if n_classes == 2:
|
||||||
{
|
params = {
|
||||||
"tree_method": "gpu_hist",
|
"tree_method": "gpu_hist",
|
||||||
"booster": "dart",
|
"booster": "dart",
|
||||||
"rate_drop": 0.5,
|
"rate_drop": 0.5,
|
||||||
},
|
"objective": "binary:logistic"
|
||||||
Xy,
|
}
|
||||||
num_boost_round=32
|
else:
|
||||||
)
|
params = {
|
||||||
|
"tree_method": "gpu_hist",
|
||||||
|
"booster": "dart",
|
||||||
|
"rate_drop": 0.5,
|
||||||
|
"objective": "multi:softprob",
|
||||||
|
"num_class": n_classes
|
||||||
|
}
|
||||||
|
|
||||||
|
booster = xgb.train(params, Xy, num_boost_round=32)
|
||||||
# predictor=auto
|
# predictor=auto
|
||||||
inplace = booster.inplace_predict(X)
|
inplace = booster.inplace_predict(X)
|
||||||
copied = booster.predict(Xy)
|
copied = booster.predict(Xy)
|
||||||
|
cpu_inplace = booster.inplace_predict(X_)
|
||||||
|
booster.set_param({"predictor": "cpu_predictor"})
|
||||||
|
cpu_copied = booster.predict(Xy)
|
||||||
|
|
||||||
copied = cp.array(copied)
|
copied = cp.array(copied)
|
||||||
|
cp.testing.assert_allclose(cpu_inplace, copied, atol=1e-6)
|
||||||
|
cp.testing.assert_allclose(cpu_copied, copied, atol=1e-6)
|
||||||
cp.testing.assert_allclose(inplace, copied, atol=1e-6)
|
cp.testing.assert_allclose(inplace, copied, atol=1e-6)
|
||||||
|
|
||||||
booster.set_param({"predictor": "gpu_predictor"})
|
booster.set_param({"predictor": "gpu_predictor"})
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user