* Address windows compilation error * Do not allow divide by zero in weight calculation * Update tests
This commit is contained in:
parent
516457fadc
commit
f00fd87b36
@ -101,7 +101,8 @@ class SoftmaxMultiClassObj : public ObjFunction {
|
|||||||
for (int k = 0; k < nclass; ++k) {
|
for (int k = 0; k < nclass; ++k) {
|
||||||
// Computation duplicated to avoid creating a cache.
|
// Computation duplicated to avoid creating a cache.
|
||||||
bst_float p = expf(point[k] - wmax) / static_cast<float>(wsum);
|
bst_float p = expf(point[k] - wmax) / static_cast<float>(wsum);
|
||||||
const bst_float h = fmax(2.0f * p * (1.0f - p) * wt, kRtEps);
|
const float eps = 1e-16f;
|
||||||
|
const bst_float h = fmax(2.0f * p * (1.0f - p) * wt, eps);
|
||||||
p = label == k ? p - 1.0f : p;
|
p = label == k ? p - 1.0f : p;
|
||||||
gpair[idx * nclass + k] = GradientPair(p * wt, h);
|
gpair[idx * nclass + k] = GradientPair(p * wt, h);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -292,7 +292,7 @@ XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess
|
|||||||
template <typename TrainingParams, typename T>
|
template <typename TrainingParams, typename T>
|
||||||
XGBOOST_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad,
|
XGBOOST_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad,
|
||||||
T sum_hess) {
|
T sum_hess) {
|
||||||
if (sum_hess < p.min_child_weight) {
|
if (sum_hess < p.min_child_weight || sum_hess <= 0.0) {
|
||||||
return 0.0;
|
return 0.0;
|
||||||
}
|
}
|
||||||
T dw;
|
T dw;
|
||||||
|
|||||||
@ -28,8 +28,8 @@ class TestGPU(unittest.TestCase):
|
|||||||
assert_gpu_results(cpu_results, gpu_results)
|
assert_gpu_results(cpu_results, gpu_results)
|
||||||
|
|
||||||
def test_gpu_hist(self):
|
def test_gpu_hist(self):
|
||||||
variable_param = {'n_gpus': [-1], 'max_depth': [2, 10], 'max_leaves': [255, 4],
|
variable_param = {'n_gpus': [-1], 'max_depth': [2, 8], 'max_leaves': [255, 4],
|
||||||
'max_bin': [2, 256],
|
'max_bin': [2, 256], 'min_child_weight': [0, 1], 'lambda': [0.0, 1.0],
|
||||||
'grow_policy': ['lossguide']}
|
'grow_policy': ['lossguide']}
|
||||||
for param in parameter_combinations(variable_param):
|
for param in parameter_combinations(variable_param):
|
||||||
param['tree_method'] = 'gpu_hist'
|
param['tree_method'] = 'gpu_hist'
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user