[PYTHON] Refactor trainnig API to use callback

This commit is contained in:
tqchen
2016-05-19 17:47:11 -07:00
parent 03996dd4e8
commit 149589c583
18 changed files with 492 additions and 278 deletions

View File

@@ -35,6 +35,22 @@ class TestBasic(unittest.TestCase):
# assert they are the same
assert np.sum(np.abs(preds2 - preds)) == 0
def test_record_results(self):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
# specify validations set to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 2
result = {}
res2 = {}
xgb.train(param, dtrain, num_round, watchlist,
callbacks=[xgb.callback.record_evaluation(result)])
xgb.train(param, dtrain, num_round, watchlist,
evals_result=res2)
assert result['train']['error'][0] < 0.1
assert res2 == result
def test_multiclass(self):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
@@ -189,5 +205,5 @@ class TestBasic(unittest.TestCase):
# return np.ndarray
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=False)
assert isinstance(cv, np.ndarray)
assert cv.shape == (10, 4)
assert isinstance(cv, dict)
assert len(cv) == (4)

View File

@@ -1,5 +1,5 @@
import xgboost as xgb
import xgboost.testing as tm
import testing as tm
import numpy as np
import unittest

View File

@@ -1,5 +1,5 @@
import xgboost as xgb
import xgboost.testing as tm
import testing as tm
import numpy as np
import unittest

View File

@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
import numpy as np
import xgboost as xgb
import xgboost.testing as tm
import testing as tm
import unittest
try:

View File

@@ -1,5 +1,5 @@
import xgboost as xgb
import xgboost.testing as tm
import testing as tm
import numpy as np
import unittest

View File

@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
import numpy as np
import xgboost as xgb
import xgboost.testing as tm
import testing as tm
import unittest
try:

View File

@@ -1,7 +1,7 @@
import numpy as np
import random
import xgboost as xgb
import xgboost.testing as tm
import testing as tm
rng = np.random.RandomState(1994)

22
tests/python/testing.py Normal file
View File

@@ -0,0 +1,22 @@
# coding: utf-8
import nose
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED
def _skip_if_no_sklearn():
if not SKLEARN_INSTALLED:
raise nose.SkipTest()
def _skip_if_no_pandas():
if not PANDAS_INSTALLED:
raise nose.SkipTest()
def _skip_if_no_matplotlib():
try:
import matplotlib.pyplot as _ # noqa
except ImportError:
raise nose.SkipTest()